fastMONAI 0.5.4__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- fastMONAI/__init__.py +1 -1
- fastMONAI/_modidx.py +53 -1
- fastMONAI/dataset_info.py +144 -7
- fastMONAI/utils.py +296 -7
- fastMONAI/vision_augmentation.py +137 -5
- fastMONAI/vision_patch.py +157 -23
- fastMONAI/vision_plot.py +89 -1
- {fastmonai-0.5.4.dist-info → fastmonai-0.6.0.dist-info}/METADATA +1 -1
- fastmonai-0.6.0.dist-info/RECORD +21 -0
- fastmonai-0.5.4.dist-info/RECORD +0 -21
- {fastmonai-0.5.4.dist-info → fastmonai-0.6.0.dist-info}/WHEEL +0 -0
- {fastmonai-0.5.4.dist-info → fastmonai-0.6.0.dist-info}/entry_points.txt +0 -0
- {fastmonai-0.5.4.dist-info → fastmonai-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {fastmonai-0.5.4.dist-info → fastmonai-0.6.0.dist-info}/top_level.txt +0 -0
fastMONAI/utils.py
CHANGED

```diff
@@ -2,7 +2,7 @@
 
 # %% auto 0
 __all__ = ['store_variables', 'load_variables', 'store_patch_variables', 'load_patch_variables', 'print_colab_gpu_info',
-           'ModelTrackingCallback', 'MLflowUIManager']
+           'ModelTrackingCallback', 'create_mlflow_callback', 'MLflowUIManager']
 
 # %% ../nbs/07_utils.ipynb 1
 import pickle
@@ -13,6 +13,7 @@ import mlflow.pytorch
 import os
 import tempfile
 import json
+from datetime import datetime
 from fastai.callback.core import Callback
 from fastcore.foundation import L
 from typing import Any
@@ -136,12 +137,147 @@ def print_colab_gpu_info():
     else: print(colab_gpu_msg)
 
 # %% ../nbs/07_utils.ipynb 9
+def _detect_patch_workflow(dls) -> bool:
+    """Detect if DataLoaders are patch-based (MedPatchDataLoaders).
+
+    Args:
+        dls: DataLoaders instance
+
+    Returns:
+        True if dls is a MedPatchDataLoaders instance
+    """
+    return hasattr(dls, 'patch_config') or hasattr(dls, '_patch_config')
+
+
+def _extract_size_from_transforms(tfms) -> list | None:
+    """Extract target size from PadOrCrop transform if present.
+
+    Args:
+        tfms: List of transforms
+
+    Returns:
+        Target size as list, or None if not found
+    """
+    if tfms is None:
+        return None
+    for tfm in tfms:
+        if hasattr(tfm, 'pad_or_crop') and hasattr(tfm.pad_or_crop, 'target_shape'):
+            return list(tfm.pad_or_crop.target_shape)
+    return None
+
+
+def _extract_standard_config(learn) -> dict:
+    """Extract config from standard MedDataBlock workflow.
+
+    Args:
+        learn: fastai Learner instance
+
+    Returns:
+        Dictionary with extracted configuration
+    """
+    from fastMONAI.vision_core import MedBase
+    dls = learn.dls
+
+    # Get preprocessing from MedBase class attributes
+    apply_reorder = MedBase.apply_reorder
+    target_spacing = MedBase.target_spacing
+
+    # Extract item_tfms from DataLoaders pipeline
+    item_tfms = []
+    if hasattr(dls, 'after_item') and dls.after_item:
+        item_tfms = list(dls.after_item.fs)
+
+    # Extract size from PadOrCrop transform
+    size = _extract_size_from_transforms(item_tfms)
+
+    return {
+        'apply_reorder': apply_reorder,
+        'target_spacing': target_spacing,
+        'size': size,
+        'item_tfms': item_tfms,
+        'batch_size': dls.bs,
+        'patch_config': None,
+    }
+
+
+def _extract_patch_config(learn) -> dict:
+    """Extract config from MedPatchDataLoaders workflow.
+
+    Args:
+        learn: fastai Learner instance
+
+    Returns:
+        Dictionary with extracted configuration including patch-specific params
+    """
+    dls = learn.dls
+    patch_config = getattr(dls, '_patch_config', None) or getattr(dls, 'patch_config', None)
+
+    config = {
+        'apply_reorder': getattr(dls, '_apply_reorder', patch_config.apply_reorder if patch_config else False),
+        'target_spacing': getattr(dls, '_target_spacing', patch_config.target_spacing if patch_config else None),
+        'size': patch_config.patch_size if patch_config else None,
+        'item_tfms': getattr(dls, '_pre_patch_tfms', []) or [],
+        'batch_size': dls.bs,
+    }
+
+    # Add patch-specific params for logging
+    if patch_config:
+        config['patch_config'] = {
+            'patch_size': patch_config.patch_size,
+            'patch_overlap': patch_config.patch_overlap,
+            'samples_per_volume': patch_config.samples_per_volume,
+            'sampler_type': patch_config.sampler_type,
+            'label_probabilities': str(patch_config.label_probabilities) if patch_config.label_probabilities else None,
+            'queue_length': patch_config.queue_length,
+            'aggregation_mode': patch_config.aggregation_mode,
+            'padding_mode': patch_config.padding_mode,
+            'keep_largest_component': patch_config.keep_largest_component,
+        }
+    else:
+        config['patch_config'] = None
+
+    return config
+
+
+def _extract_loss_name(learn) -> str:
+    """Extract loss function name from Learner.
+
+    Args:
+        learn: fastai Learner instance
+
+    Returns:
+        Name of the loss function
+    """
+    loss_func = learn.loss_func
+    # Handle CustomLoss wrapper
+    if hasattr(loss_func, 'loss_func'):
+        inner = loss_func.loss_func
+        return inner._get_name() if hasattr(inner, '_get_name') else inner.__class__.__name__
+    return loss_func._get_name() if hasattr(loss_func, '_get_name') else loss_func.__class__.__name__
+
+
+def _extract_model_name(learn) -> str:
+    """Extract model architecture name from Learner.
+
+    Args:
+        learn: fastai Learner instance
+
+    Returns:
+        Name of the model architecture
+    """
+    model = learn.model
+    return model._get_name() if hasattr(model, '_get_name') else model.__class__.__name__
+
+# %% ../nbs/07_utils.ipynb 10
 class ModelTrackingCallback(Callback):
     """
     A FastAI callback for comprehensive MLflow experiment tracking.
 
     This callback automatically logs hyperparameters, metrics, model artifacts,
-    and configuration to MLflow during training.
+    and configuration to MLflow during training. If a SaveModelCallback is present,
+    the best model checkpoint will also be logged as an artifact.
+
+    Supports auto-managed runs when created via `create_mlflow_callback()`.
     """
 
     def __init__(
```
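Taken together, these private helpers feed `create_mlflow_callback`, whose diff continues below: detect the loader type, pull the matching config, and read the model and loss names off the Learner. A minimal sketch of that dispatch, assuming a hypothetical `learn` built with fastMONAI loaders (the helper names are the private functions added above):

```python
# Sketch only: mirrors the dispatch that create_mlflow_callback performs.
def sketch_auto_config(learn) -> dict:
    if _detect_patch_workflow(learn.dls):
        # MedPatchDataLoaders: 'size' is the patch size from PatchConfig
        config = _extract_patch_config(learn)
    else:
        # Standard MedDataBlock: 'size' comes from a PadOrCrop item transform
        config = _extract_standard_config(learn)

    config['model_name'] = _extract_model_name(learn)  # e.g. 'UNet'
    config['loss_name'] = _extract_loss_name(learn)    # unwraps a CustomLoss wrapper
    return config
```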
```diff
@@ -151,7 +287,13 @@ class ModelTrackingCallback(Callback):
         item_tfms: list[Any],
         size: list[int],
         target_spacing: list[float],
-        apply_reorder: bool
+        apply_reorder: bool,
+        experiment_name: str = None,
+        run_name: str = None,
+        auto_start: bool = False,
+        patch_config: dict = None,
+        extra_params: dict = None,
+        extra_tags: dict = None
     ):
         """
         Initialize the MLflow tracking callback.
@@ -159,9 +301,16 @@ class ModelTrackingCallback(Callback):
         Args:
             model_name: Name of the model architecture for registration
             loss_function: Name of the loss function being used
+            item_tfms: List of item transforms
             size: Model input dimensions
             target_spacing: Resampling dimensions
             apply_reorder: Whether reordering augmentation is applied
+            experiment_name: MLflow experiment name (used with auto_start)
+            run_name: MLflow run name (auto-generated if None)
+            auto_start: If True, auto-starts/stops MLflow run
+            patch_config: Patch configuration dict for logging (from MedPatchDataLoaders)
+            extra_params: Additional parameters to log
+            extra_tags: MLflow tags to set on the run
         """
         self.model_name = model_name
         self.loss_function = loss_function
@@ -170,6 +319,15 @@ class ModelTrackingCallback(Callback):
         self.target_spacing = target_spacing
         self.apply_reorder = apply_reorder
 
+        # New auto-management fields
+        self.experiment_name = experiment_name
+        self.run_name = run_name
+        self.auto_start = auto_start
+        self.patch_config = patch_config
+        self.extra_params = extra_params or {}
+        self.extra_tags = extra_tags or {}
+        self._auto_started = False
+
         self.config = self._build_config()
 
     def extract_all_params(self, tfm):
@@ -235,6 +393,12 @@ class ModelTrackingCallback(Callback):
             separators=(',', ': ')
         )
 
+        # Add patch-specific params if present
+        if self.patch_config:
+            for key, value in self.patch_config.items():
+                if value is not None:
+                    params[f"patch_{key}"] = value if isinstance(value, (int, float, str, bool)) else str(value)
+
         return params
 
     def _extract_epoch_metrics(self) -> dict[str, float]:
@@ -284,6 +448,23 @@ class ModelTrackingCallback(Callback):
         if os.path.exists(weights_file):
             mlflow.log_artifact(weights_file, "model")
 
+        # Auto-detect SaveModelCallback and log best model
+        from fastai.callback.tracker import SaveModelCallback
+        best_model_cb = None
+        for cb in self.learn.cbs:
+            if isinstance(cb, SaveModelCallback):
+                best_model_path = self.learn.path/self.learn.model_dir/f'{cb.fname}.pth'
+                if best_model_path.exists():
+                    mlflow.log_artifact(str(best_model_path), "model")
+                    print(f"Logged best model artifact: {cb.fname}.pth")
+                best_model_cb = cb
+                break
+
+        # Load best model weights before export if SaveModelCallback was used
+        if best_model_cb is not None:
+            self.learn.load(best_model_cb.fname)
+            print(f"Loaded best model weights ({best_model_cb.fname}) for learner export")
+
         # Remove MLflow callbacks before exporting learner for inference
         # This prevents the callback from being triggered during inference
         original_cbs = self.learn.cbs.copy()  # Save original callbacks
@@ -312,8 +493,22 @@ class ModelTrackingCallback(Callback):
         )
 
     def before_fit(self) -> None:
-        """Log hyperparameters before training starts."""
+        """Log hyperparameters before training starts. Auto-start run if configured."""
+        # Auto-start run if requested
+        if self.auto_start:
+            if self.experiment_name:
+                mlflow.set_experiment(self.experiment_name)
+            self.run_name = self.run_name or f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+            mlflow.start_run(run_name=self.run_name)
+            self._auto_started = True
+
+        # Set tags
+        if self.extra_tags:
+            mlflow.set_tags(self.extra_tags)
+
+        # Log params
         params = self._extract_training_params()
+        params.update(self.extra_params)
         mlflow.log_params(params)
 
     def after_epoch(self) -> None:
@@ -332,10 +527,104 @@ class ModelTrackingCallback(Callback):
             self._save_model_artifacts(temp_path)
 
             self._register_pytorch_model()
-
-
+
+            run_id = mlflow.active_run().info.run_id
+            print(f"MLflow run completed. Run ID: {run_id}")
+
+        # End run if auto-started
+        if self._auto_started:
+            mlflow.end_run()
+            self._auto_started = False
 
-# %% ../nbs/07_utils.ipynb
+# %% ../nbs/07_utils.ipynb 11
+def create_mlflow_callback(
+    learn,
+    experiment_name: str = None,
+    run_name: str = None,
+    auto_start: bool = True,
+    model_name: str = None,
+    extra_params: dict = None,
+    extra_tags: dict = None,
+) -> ModelTrackingCallback:
+    """Create MLflow tracking callback with auto-extracted configuration.
+
+    This factory function automatically extracts configuration from the Learner,
+    eliminating the need to manually specify parameters like size, transforms,
+    loss function, etc.
+
+    Auto-extracts from Learner:
+    - Preprocessing: apply_reorder, target_spacing, size/patch_size
+    - Transforms: item_tfms or pre_patch_tfms
+    - Training: loss_func, model architecture
+
+    Args:
+        learn: fastai Learner instance
+        experiment_name: MLflow experiment name. If None, uses model name.
+        run_name: MLflow run name. If None, auto-generates with timestamp.
+        auto_start: If True, auto-starts/stops MLflow run in before_fit/after_fit.
+        model_name: Override auto-extracted model name for registration.
+        extra_params: Additional parameters to log (e.g., {'dropout': 0.5}).
+        extra_tags: MLflow tags to set on the run.
+
+    Returns:
+        ModelTrackingCallback ready to use with learn.fit()
+
+    Example:
+        >>> # Instead of this (6 manual params):
+        >>> # mlflow_callback = ModelTrackingCallback(
+        >>> #     model_name=f"{task}_{model._get_name()}",
+        >>> #     loss_function=loss_func.loss_func._get_name(),
+        >>> #     item_tfms=item_tfms,
+        >>> #     size=size,
+        >>> #     target_spacing=target_spacing,
+        >>> #     apply_reorder=True,
+        >>> # )
+        >>> # with mlflow.start_run(run_name="training"):
+        >>> #     learn.fit_one_cycle(30, lr, cbs=[mlflow_callback])
+        >>>
+        >>> # Do this (zero manual params):
+        >>> callback = create_mlflow_callback(learn, experiment_name="Task02_Heart")
+        >>> learn.fit_one_cycle(30, lr, cbs=[callback, save_best])
+    """
+    # Detect workflow and extract config
+    if _detect_patch_workflow(learn.dls):
+        config = _extract_patch_config(learn)
+    else:
+        config = _extract_standard_config(learn)
+
+    # Extract model/loss info
+    _model_name = model_name or _extract_model_name(learn)
+    _loss_name = _extract_loss_name(learn)
+
+    # Set experiment name
+    if experiment_name is None:
+        experiment_name = _model_name
+
+    # Validate size was extracted
+    if config['size'] is None:
+        raise ValueError(
+            "Could not auto-extract 'size'. Either:\n"
+            "1. Add PadOrCrop to item_tfms, or\n"
+            "2. Use MedPatchDataLoaders with PatchConfig, or\n"
+            "3. Use ModelTrackingCallback directly with manual params"
+        )
+
+    return ModelTrackingCallback(
+        model_name=_model_name,
+        loss_function=_loss_name,
+        item_tfms=config['item_tfms'],
+        size=config['size'],
+        target_spacing=config['target_spacing'],
+        apply_reorder=config['apply_reorder'],
+        experiment_name=experiment_name,
+        run_name=run_name,
+        auto_start=auto_start,
+        patch_config=config['patch_config'],
+        extra_params=extra_params,
+        extra_tags=extra_tags,
+    )
+
+# %% ../nbs/07_utils.ipynb 13
 import subprocess
 import threading
 import time
```
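With `auto_start=True` (the factory's default), the callback owns the MLflow run lifecycle, so the `with mlflow.start_run(...)` wrapper shown in the old-style docstring example is no longer required. A minimal usage sketch; the experiment name, schedule, and learning rate are illustrative, and `learn` is assumed to be an existing fastMONAI Learner:

```python
from fastai.callback.tracker import SaveModelCallback
from fastMONAI.utils import create_mlflow_callback

# All tracked config (size, target_spacing, item_tfms, loss, model) is
# auto-extracted from `learn`; the run starts in before_fit and ends in
# after_fit because auto_start defaults to True.
cb = create_mlflow_callback(
    learn,
    experiment_name="Task02_Heart",
    extra_params={"fold": 0},          # merged into the logged params
    extra_tags={"stage": "baseline"},  # set on the run via mlflow.set_tags
)

# A SaveModelCallback is auto-detected: its best checkpoint is logged as
# an MLflow artifact and reloaded before the learner is exported.
learn.fit_one_cycle(30, 1e-3, cbs=[cb, SaveModelCallback(fname="best")])
```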
fastMONAI/vision_augmentation.py
CHANGED

```diff
@@ -3,8 +3,8 @@
 # %% auto 0
 __all__ = ['CustomDictTransform', 'do_pad_or_crop', 'PadOrCrop', 'ZNormalization', 'RescaleIntensity', 'NormalizeIntensity',
            'BraTSMaskConverter', 'BinaryConverter', 'RandomGhosting', 'RandomSpike', 'RandomNoise', 'RandomBiasField',
-           'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion', 'RandomElasticDeformation',
-           'RandomAffine', 'RandomFlip', 'OneOf']
+           'RandomBlur', 'RandomGamma', 'RandomIntensityScale', 'RandomMotion', 'RandomCutout',
+           'RandomElasticDeformation', 'RandomAffine', 'RandomFlip', 'OneOf']
 
 # %% ../nbs/03_vision_augment.ipynb 2
 from fastai.data.all import *
```
```diff
@@ -425,7 +425,139 @@ class RandomMotion(DisplayedTransform):
     def encodes(self, o: MedMask):
         return o
 
+# %% ../nbs/03_vision_augment.ipynb 22
+def _create_ellipsoid_mask(shape, center, radii):
+    """Create a 3D ellipsoid mask.
+
+    Args:
+        shape: (D, H, W) shape of the volume
+        center: (z, y, x) center of ellipsoid
+        radii: (rz, ry, rx) radii along each axis
+
+    Returns:
+        Boolean mask where True = inside ellipsoid
+    """
+    z, y, x = torch.meshgrid(
+        torch.arange(shape[0]),
+        torch.arange(shape[1]),
+        torch.arange(shape[2]),
+        indexing='ij'
+    )
+    dist = ((z - center[0]) / radii[0]) ** 2 + \
+           ((y - center[1]) / radii[1]) ** 2 + \
+           ((x - center[2]) / radii[2]) ** 2
+    return dist <= 1.0
+
 # %% ../nbs/03_vision_augment.ipynb 23
+class _TioRandomCutout(tio.IntensityTransform):
+    """TorchIO-compatible RandomCutout for patch-based workflows."""
+
+    def __init__(self, holes=1, spatial_size=8, fill_value=None,
+                 max_holes=None, max_spatial_size=None, p=0.2, **kwargs):
+        super().__init__(p=p, **kwargs)
+        self.holes = holes
+        self.spatial_size = spatial_size
+        self.fill_value = fill_value
+        self.max_holes = max_holes
+        self.max_spatial_size = max_spatial_size
+
+    def _apply_cutout(self, data, fill_val):
+        """Apply spherical cutout(s) to a tensor."""
+        result = data.clone()
+        n_holes = torch.randint(self.holes, (self.max_holes or self.holes) + 1, (1,)).item()
+
+        spatial_shape = data.shape[1:]  # (D, H, W)
+        min_size = self.spatial_size if isinstance(self.spatial_size, int) else self.spatial_size[0]
+        max_size = self.max_spatial_size or self.spatial_size
+        max_size = max_size if isinstance(max_size, int) else max_size[0]
+
+        for _ in range(n_holes):
+            # Random size for this hole
+            size = torch.randint(min_size, max_size + 1, (3,))
+            radii = size.float() / 2
+
+            # Random center (ensure hole fits in volume)
+            center = [
+                torch.randint(int(radii[i].item()),
+                              max(spatial_shape[i] - int(radii[i].item()), int(radii[i].item()) + 1),
+                              (1,)).item()
+                for i in range(3)
+            ]
+
+            mask = _create_ellipsoid_mask(spatial_shape, center, radii)
+            result[:, mask] = fill_val
+
+        return result
+
+    def apply_transform(self, subject):
+        for image in self.get_images(subject):
+            data = image.data
+            fill_val = self.fill_value if self.fill_value is not None else float(data.min())
+            result = self._apply_cutout(data, fill_val)
+            image.set_data(result)
+        return subject
+
+# %% ../nbs/03_vision_augment.ipynb 24
+class RandomCutout(DisplayedTransform):
+    """Randomly erase spherical regions in 3D medical images.
+
+    Simulates post-operative surgical cavities by filling random ellipsoid
+    volumes with specified values. Useful for training on pre-op images
+    to generalize to post-op scans.
+
+    Args:
+        holes: Minimum number of cutout regions. Default: 1.
+        max_holes: Maximum number of regions. Default: 3.
+        spatial_size: Minimum cutout diameter in voxels. Default: 8.
+        max_spatial_size: Maximum cutout diameter. Default: 16.
+        fill: Fill value - 'min', 'mean', 'random', or float. Default: 'min'.
+        p: Probability of applying transform. Default: 0.2.
+
+    Example:
+        >>> # Simulate post-op cavities with dark spherical voids
+        >>> tfm = RandomCutout(holes=1, max_holes=2, spatial_size=10,
+        ...                    max_spatial_size=25, fill='min', p=0.2)
+    """
+
+    split_idx, order = 0, 1
+
+    def __init__(self, holes=1, max_holes=3, spatial_size=8,
+                 max_spatial_size=16, fill='min', p=0.2):
+        self.holes = holes
+        self.max_holes = max_holes
+        self.spatial_size = spatial_size
+        self.max_spatial_size = max_spatial_size
+        self.fill = fill
+        self.p = p
+
+        self._tio_cutout = _TioRandomCutout(
+            holes=holes, spatial_size=spatial_size,
+            fill_value=None if isinstance(fill, str) else fill,
+            max_holes=max_holes, max_spatial_size=max_spatial_size, p=p
+        )
+
+    @property
+    def tio_transform(self):
+        """Return TorchIO-compatible transform for patch-based workflows."""
+        return self._tio_cutout
+
+    def _get_fill_value(self, tensor):
+        if self.fill == 'min': return float(tensor.min())
+        elif self.fill == 'mean': return float(tensor.mean())
+        elif self.fill == 'random':
+            return torch.empty(1).uniform_(float(tensor.min()), float(tensor.max())).item()
+        else: return self.fill
+
+    def encodes(self, o: MedImage):
+        if torch.rand(1).item() > self.p: return o
+        fill_val = self._get_fill_value(o)
+        result = self._tio_cutout._apply_cutout(o.clone(), fill_val)
+        return MedImage.create(result)
+
+    def encodes(self, o: MedMask):
+        return o
+
+# %% ../nbs/03_vision_augment.ipynb 26
 class RandomElasticDeformation(CustomDictTransform):
     """Apply TorchIO `RandomElasticDeformation`."""
 
@@ -438,7 +570,7 @@ class RandomElasticDeformation(CustomDictTransform):
             image_interpolation=image_interpolation,
             p=p))
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 27
 class RandomAffine(CustomDictTransform):
     """Apply TorchIO `RandomAffine`."""
 
@@ -454,14 +586,14 @@ class RandomAffine(CustomDictTransform):
             default_pad_value=default_pad_value,
             p=p))
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 28
 class RandomFlip(CustomDictTransform):
     """Apply TorchIO `RandomFlip`."""
 
     def __init__(self, axes='LR', p=0.5):
         super().__init__(tio.RandomFlip(axes=axes, flip_probability=p))
 
-# %% ../nbs/03_vision_augment.ipynb
+# %% ../nbs/03_vision_augment.ipynb 29
 class OneOf(CustomDictTransform):
     """Apply only one of the given transforms using TorchIO `OneOf`."""
 
```
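The new `RandomCutout` modifies `MedImage` inputs only (the `MedMask` overload returns masks untouched) and exposes `tio_transform` for patch-based pipelines. A minimal sketch of both entry points plus a sanity check of the ellipsoid geometry; the sizes and probability are illustrative, and `_create_ellipsoid_mask` is a private helper imported here only for the demonstration:

```python
import torch
from fastMONAI.vision_augmentation import RandomCutout, _create_ellipsoid_mask

# Standard workflow: drop it into item_tfms alongside the other transforms,
# e.g. item_tfms = [PadOrCrop(size), RandomCutout(p=0.2), ZNormalization()]
tfm = RandomCutout(holes=1, max_holes=2, spatial_size=10,
                   max_spatial_size=25, fill='min', p=1.0)  # p=1.0 forces it

# Patch-based workflow: hand the TorchIO-compatible wrapper to the queue.
tio_cutout = tfm.tio_transform  # a _TioRandomCutout instance

# Ellipsoid geometry: a radius-4 ball centred in a 16^3 volume covers about
# (4/3)*pi*4^3 ≈ 268 voxels (inexact after discretisation to the voxel grid).
mask = _create_ellipsoid_mask((16, 16, 16), (8, 8, 8), (4.0, 4.0, 4.0))
print(int(mask.sum()))
```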