brainmint-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. brainmint/__init__.py +12 -0
  2. brainmint/_version.py +24 -0
  3. brainmint/callbacks/log_images.py +341 -0
  4. brainmint/callbacks/modality_completion_schedule.py +315 -0
  5. brainmint/data/__init__.py +1 -0
  6. brainmint/data/brainscape.py +401 -0
  7. brainmint/data/brainscape_paired.py +945 -0
  8. brainmint/data/transforms/__init__.py +1 -0
  9. brainmint/data/transforms/conditioning.py +52 -0
  10. brainmint/data/transforms/copy_paths.py +44 -0
  11. brainmint/data/transforms/demographics.py +378 -0
  12. brainmint/data/transforms/intensity.py +50 -0
  13. brainmint/data/transforms/modality_choice.py +287 -0
  14. brainmint/data/transforms/mri_vae.py +308 -0
  15. brainmint/data/transforms/segmentation.py +478 -0
  16. brainmint/data/transforms/stream_mapping.py +226 -0
  17. brainmint/data/transforms/synthetic_mask.py +580 -0
  18. brainmint/data/utils.py +35 -0
  19. brainmint/external/__init__.py +20 -0
  20. brainmint/external/registry.py +78 -0
  21. brainmint/external/repo_manager.py +333 -0
  22. brainmint/external/sys_path.py +32 -0
  23. brainmint/inference/__init__.py +1 -0
  24. brainmint/inference/controlnet/__init__.py +5 -0
  25. brainmint/inference/controlnet/controlnet_inference.py +548 -0
  26. brainmint/inference/core/__init__.py +23 -0
  27. brainmint/inference/core/context.py +76 -0
  28. brainmint/inference/core/interfaces.py +70 -0
  29. brainmint/inference/core/runner.py +93 -0
  30. brainmint/inference/core/scheduler.py +98 -0
  31. brainmint/inference/diffusion/__init__.py +2 -0
  32. brainmint/inference/diffusion/conditioning/__init__.py +2 -0
  33. brainmint/inference/diffusion/conditioning/base.py +67 -0
  34. brainmint/inference/diffusion/conditioning/common/__init__.py +1 -0
  35. brainmint/inference/diffusion/conditioning/common/class_labels.py +83 -0
  36. brainmint/inference/diffusion/conditioning/common/constant.py +155 -0
  37. brainmint/inference/diffusion/conditioning/common/demographics.py +85 -0
  38. brainmint/inference/diffusion/conditioning/common/random_vec.py +35 -0
  39. brainmint/inference/diffusion/conditioning/common/vector_ops.py +33 -0
  40. brainmint/inference/diffusion/conditioning/ldm/__init__.py +1 -0
  41. brainmint/inference/diffusion/conditioning/ldm/fixed_4vec.py +41 -0
  42. brainmint/inference/diffusion/latent/__init__.py +2 -0
  43. brainmint/inference/diffusion/latent/base.py +7 -0
  44. brainmint/inference/diffusion/latent/common/__init__.py +1 -0
  45. brainmint/inference/diffusion/latent/common/fixed_shape.py +54 -0
  46. brainmint/inference/diffusion/latent/common/from_autoencoder_encode.py +43 -0
  47. brainmint/inference/diffusion/latent/common/from_batch.py +30 -0
  48. brainmint/inference/diffusion/latent/common/from_image_shape.py +52 -0
  49. brainmint/inference/diffusion/pipelines/__init__.py +1 -0
  50. brainmint/inference/diffusion/pipelines/base.py +121 -0
  51. brainmint/inference/diffusion/pipelines/generation/__init__.py +1 -0
  52. brainmint/inference/diffusion/pipelines/generation/ldm/__init__.py +1 -0
  53. brainmint/inference/diffusion/pipelines/generation/ldm/ukb_ddim.py +14 -0
  54. brainmint/inference/diffusion/samplers/__init__.py +2 -0
  55. brainmint/inference/diffusion/samplers/base.py +85 -0
  56. brainmint/inference/diffusion/samplers/common/cond_unet.py +69 -0
  57. brainmint/inference/diffusion/samplers/common/monai_unet.py +240 -0
  58. brainmint/inference/diffusion/samplers/ldm/__init__.py +1 -0
  59. brainmint/inference/diffusion/samplers/ldm/ukb_ddim.py +50 -0
  60. brainmint/inference/dynamic_inference.py +33 -0
  61. brainmint/inference/generation/__init__.py +1 -0
  62. brainmint/inference/generation/batch_builders/__init__.py +1 -0
  63. brainmint/inference/generation/batch_builders/med_ddpm.py +50 -0
  64. brainmint/inference/generation/pipelines/__init__.py +1 -0
  65. brainmint/inference/generation/pipelines/common.py +6 -0
  66. brainmint/inference/generation/pipelines/external.py +13 -0
  67. brainmint/inference/generation/pipelines/hagan.py +43 -0
  68. brainmint/inference/generation/pipelines/maisi.py +52 -0
  69. brainmint/inference/generation/pipelines/med_ddpm.py +89 -0
  70. brainmint/inference/generation/pipelines/wdm3d.py +58 -0
  71. brainmint/inference/io/__init__.py +5 -0
  72. brainmint/inference/io/base.py +33 -0
  73. brainmint/inference/io/dataset_writers.py +125 -0
  74. brainmint/inference/io/readers.py +56 -0
  75. brainmint/inference/io/writers.py +191 -0
  76. brainmint/inference/postprocess/__init__.py +3 -0
  77. brainmint/inference/postprocess/base.py +13 -0
  78. brainmint/inference/postprocess/brats_pipeline.py +317 -0
  79. brainmint/inference/postprocess/reorient.py +101 -0
  80. brainmint/inference/translation/__init__.py +17 -0
  81. brainmint/inference/translation/generators/__init__.py +12 -0
  82. brainmint/inference/translation/generators/aldm.py +7 -0
  83. brainmint/inference/translation/generators/cwdm.py +7 -0
  84. brainmint/inference/translation/pipelines/__init__.py +1 -0
  85. brainmint/integrations/__init__.py +1 -0
  86. brainmint/integrations/aldm/__init__.py +10 -0
  87. brainmint/integrations/aldm/ldm.py +47 -0
  88. brainmint/integrations/aldm/repo.py +132 -0
  89. brainmint/integrations/aldm/vqgan.py +136 -0
  90. brainmint/integrations/brainsynth/__init__.py +7 -0
  91. brainmint/integrations/brainsynth/inferer.py +91 -0
  92. brainmint/integrations/brainsynth/vendor_vqvae.py +569 -0
  93. brainmint/integrations/brasyn/__init__.py +7 -0
  94. brainmint/integrations/brasyn/io.py +154 -0
  95. brainmint/integrations/brasyn/missing_mri.py +181 -0
  96. brainmint/integrations/brasyn/modalities.py +71 -0
  97. brainmint/integrations/brasyn/runtime.py +347 -0
  98. brainmint/integrations/cwdm/__init__.py +8 -0
  99. brainmint/integrations/cwdm/repo.py +31 -0
  100. brainmint/integrations/cwdm/translator.py +103 -0
  101. brainmint/integrations/hagan/__init__.py +8 -0
  102. brainmint/integrations/hagan/generator.py +35 -0
  103. brainmint/integrations/hagan/repo.py +31 -0
  104. brainmint/integrations/maisi/__init__.py +8 -0
  105. brainmint/integrations/maisi/autoencoder.py +15 -0
  106. brainmint/integrations/maisi/generator.py +301 -0
  107. brainmint/integrations/maisi/repo.py +31 -0
  108. brainmint/integrations/med_ddpm/__init__.py +8 -0
  109. brainmint/integrations/med_ddpm/generator.py +68 -0
  110. brainmint/integrations/med_ddpm/repo.py +31 -0
  111. brainmint/integrations/wdm3d/__init__.py +8 -0
  112. brainmint/integrations/wdm3d/generator.py +148 -0
  113. brainmint/integrations/wdm3d/repo.py +31 -0
  114. brainmint/lightning/__init__.py +7 -0
  115. brainmint/lightning/controlnet_module.py +1377 -0
  116. brainmint/lightning/diffusion_inference_module.py +110 -0
  117. brainmint/lightning/diffusion_module.py +1250 -0
  118. brainmint/lightning/export_latents_module.py +135 -0
  119. brainmint/lightning/generic_inference_module.py +78 -0
  120. brainmint/lightning/vae_module.py +324 -0
  121. brainmint/losses/__init__.py +1 -0
  122. brainmint/losses/mask_recon_loss.py +100 -0
  123. brainmint/losses/utils.py +23 -0
  124. brainmint/losses/vae_loss_manager.py +272 -0
  125. brainmint/metrics/__init__.py +1 -0
  126. brainmint/metrics/diffusion_metrics.py +552 -0
  127. brainmint/metrics/reconstruction.py +462 -0
  128. brainmint/models/__init__.py +1 -0
  129. brainmint/models/blocks/__init__.py +8 -0
  130. brainmint/models/blocks/haar_dwt.py +205 -0
  131. brainmint/models/blocks/haar_wavelet_fusion.py +187 -0
  132. brainmint/models/compression/__init__.py +8 -0
  133. brainmint/models/compression/aldm_vqgan.py +37 -0
  134. brainmint/models/compression/brainsynth_vqvae.py +59 -0
  135. brainmint/models/compression/dwt.py +52 -0
  136. brainmint/models/compression/ldm_vae.py +61 -0
  137. brainmint/models/compression/maisi_vae_gan.py +53 -0
  138. brainmint/models/compression/wavelet_fusion.py +714 -0
  139. brainmint/models/compression/wavelet_vae.py +500 -0
  140. brainmint/models/conditioning/demographics_encoder.py +241 -0
  141. brainmint/models/controlnet/__init__.py +5 -0
  142. brainmint/models/controlnet/controlnet.py +195 -0
  143. brainmint/models/controlnet/controlnet_monai.py +472 -0
  144. brainmint/models/generation/__init__.py +8 -0
  145. brainmint/models/generation/diffusion_unet.py +411 -0
  146. brainmint/models/generation/hagan.py +106 -0
  147. brainmint/models/generation/maisi.py +134 -0
  148. brainmint/models/generation/med_ddpm.py +125 -0
  149. brainmint/models/generation/wdm3d.py +166 -0
  150. brainmint/models/schedulers/__init__.py +1 -0
  151. brainmint/models/schedulers/rflow_scheduler.py +23 -0
  152. brainmint/models/translation/__init__.py +8 -0
  153. brainmint/models/translation/aldm.py +347 -0
  154. brainmint/models/translation/brasyn.py +45 -0
  155. brainmint/models/translation/cwdm.py +160 -0
  156. brainmint/models/translation/utils.py +106 -0
  157. brainmint/py.typed +0 -0
  158. brainmint/utils/__init__.py +1 -0
  159. brainmint/utils/batch.py +32 -0
  160. brainmint/utils/ema.py +329 -0
  161. brainmint/utils/gpumem_utils.py +50 -0
  162. brainmint/utils/schedules.py +199 -0
  163. brainmint/utils/spatial.py +71 -0
  164. brainmint/utils/state_dict_loader.py +371 -0
  165. brainmint/visualization/__init__.py +3 -0
  166. brainmint/visualization/slices.py +46 -0
  167. brainmint-0.1.0.dist-info/METADATA +109 -0
  168. brainmint-0.1.0.dist-info/RECORD +171 -0
  169. brainmint-0.1.0.dist-info/WHEEL +5 -0
  170. brainmint-0.1.0.dist-info/licenses/LICENSE +21 -0
  171. brainmint-0.1.0.dist-info/top_level.txt +1 -0
brainmint/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ """BrainMint models, integrations, and inference tools for medical image synthesis."""
2
+
3
+ from importlib.metadata import PackageNotFoundError
4
+ from importlib.metadata import version as _metadata_version
5
+
6
try:
    # Prefer the version module generated at build time.
    from ._version import __version__
except ImportError:  # ModuleNotFoundError is a subclass, so it is covered too
    # No generated module: ask the installed distribution's metadata,
    # and fall back to a placeholder when the package is not installed.
    try:
        __version__ = _metadata_version("brainmint")
    except PackageNotFoundError:
        __version__ = "0.0.0"
brainmint/_version.py ADDED
@@ -0,0 +1,24 @@
1
+ # file generated by vcs-versioning
2
+ # don't change, don't track in version control
3
+ from __future__ import annotations
4
+
5
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Release identifiers stamped into the wheel at build time.
version: str = "0.1.0"
__version__: str = version

version_tuple: tuple[int | str, ...] = (0, 1, 0)
__version_tuple__: tuple[int | str, ...] = version_tuple

# No VCS commit id was recorded for this build.
commit_id: str | None = None
__commit_id__: str | None = commit_id
brainmint/callbacks/log_images.py ADDED
@@ -0,0 +1,341 @@
1
+ import gc
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Any, Dict, List, Optional, Sequence, Callable, Union
5
+
6
+ import torch
7
+ import contextlib
8
+ import pytorch_lightning as pl
9
+ from monai.transforms import SaveImage
10
+ from monai.data import MetaTensor
11
+
12
_LOG = logging.getLogger(__name__)


def get_tb_writer(trainer):
    """Return the TensorBoard ``SummaryWriter`` attached to ``trainer``, or ``None``.

    ``trainer.loggers`` is searched when it is non-empty; otherwise the single
    ``trainer.logger`` (if set) is considered. The first
    ``pl.loggers.TensorBoardLogger`` found supplies its ``experiment``.
    """
    candidates = getattr(trainer, "loggers", None)
    if not candidates:
        only = getattr(trainer, "logger", None)
        candidates = [only] if only else []

    tb_logger = next(
        (lg for lg in candidates if isinstance(lg, pl.loggers.TensorBoardLogger)),
        None,
    )
    if tb_logger is None:
        return None
    return tb_logger.experiment  # SummaryWriter
21
+
22
+
23
class SaveMRIImages(pl.Callback):
    """
    After each validation epoch:
      • iterate the val dataloader,
      • collect ONE sample per requested modality (e.g., T1w/T2w/T1ce/FLAIR),
      • run module inference (preferred) or forward() as fallback,
      • save paired NIfTI files into: <run_dir>/<subdir>/epoch-XXX/.

    Volumes are written with MONAI ``SaveImage`` (``.nii.gz``) using postfixes:
        <tag>_<modality>_input_<KEY>        (for each saved input batch key)
        <tag>_<modality>_output_<OUTNAME>   (for each selected model output)
    ``<modality>`` is the lower-cased batch ``"modality"`` entry, or
    ``sample<NNN>`` when the batch has no ``"modality"`` key. Outputs that are
    not already ``MetaTensor``s reuse the ``filename_or_obj`` of the first saved
    input, so paired files share the same stem.

    Args:
        modalities: modality names to collect (compared case-insensitively).
        tag: prefix used in every output postfix.
        subdir: sub-directory of the run directory receiving ``epoch-XXX`` folders.
        dirpath: run directory; defaults to ``trainer.log_dir`` (or ``default_root_dir``).
        infer_kwarg_keys: batch keys gathered for inference. The first key is the
            primary input: it sets the batch size and is the input saved when
            ``input_save_list`` is empty.
        infer_method_candidates: method names looked up on the LightningModule, in order.
        fallback_to_forward: use ``pl_module.forward`` when no candidate exists.
        output_names: names for the model outputs, in the order the model returns
            them (also used to pick entries when the model returns a dict).
        save_outputs: outputs to save — a 0/1 or bool mask as long as
            ``output_names``, a list of indices, a list of names, or a single name.
        input_save_list: batch keys of inputs to save; defaults to the first infer key.
        save_dtype: dtype (or its name, e.g. ``"float16"``) applied before saving.
        clamp_min, clamp_max: range that inputs and outputs are clamped to before saving.
        output_activation: optional ``"sigmoid"`` or ``"tanh"`` applied to outputs
            before clamping.
        separate_folder: forwarded to MONAI ``SaveImage``.
        dataset_module: optional datamodule used instead of ``trainer.datamodule``;
            its ``setup()`` is called on every invocation.
    """

    # Source filename given to tensors that carry no MONAI metadata; SaveImage
    # needs a ``filename_or_obj`` entry to derive the output file name.
    _PLACEHOLDER_FILENAME = "_.nii.gz"

    def __init__(
        self,
        modalities: Sequence[str] = ("T1w", "T2w", "T1ce", "FLAIR"),
        tag: str = "val",
        subdir: str = "val_samples",
        dirpath: Optional[str] = None,  # Defaults to trainer.log_dir
        infer_kwarg_keys: Sequence[str] = ("image",),  # Batch keys to pass to inference
        infer_method_candidates: Sequence[str] = ("_run_inference",),
        fallback_to_forward: bool = True,
        output_names: Sequence[str] = ("output", "z_mu", "z_sigma"),  # Names for outputs, matching the order the model returns.
        save_outputs: Sequence[Union[int, bool, str]] = (1, 0, 0),  # Output to save, Boolean Mask | Indices | Output Names
        input_save_list: Optional[Sequence[str]] = None,
        save_dtype: torch.dtype = torch.float32,
        clamp_min: float = 0.0,
        clamp_max: float = 1.0,
        output_activation: Optional[str] = None,
        separate_folder: bool = False,
        dataset_module: Optional[Any] = None,
    ) -> None:
        super().__init__()
        self.modalities = [m.lower() for m in modalities]
        self.tag = tag
        self.subdir = subdir
        self.dirpath = dirpath
        self.infer_method_candidates = list(infer_method_candidates)
        self.fallback_to_forward = fallback_to_forward
        self.infer_kwarg_keys = list(infer_kwarg_keys)
        self.input_save_list = list(input_save_list) if input_save_list is not None else []

        self.output_names = list(output_names)
        self.save_indices = self._normalize_save_selector(save_outputs, self.output_names)

        # Accept either a torch.dtype or its string name (e.g. "float16", "torch.float16").
        self.save_dtype = save_dtype if isinstance(save_dtype, torch.dtype) else getattr(
            torch, str(save_dtype).strip().lower().removeprefix("torch.")
        )
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.output_activation = output_activation.lower() if isinstance(output_activation, str) and output_activation else None
        if self.output_activation not in (None, "sigmoid", "tanh"):
            raise ValueError(f"SaveMRIImages: unsupported output_activation={output_activation}")
        self.separate_folder = separate_folder
        self.dataset_module = dataset_module

    @staticmethod
    def _normalize_save_selector(selector: Union[str, Sequence[Union[int, bool, str]]], names: Sequence[str]) -> List[int]:
        """Normalize ``save_outputs`` into sorted, unique output indices.

        Forms are tried in this order:
          1. a 0/1 (or bool) mask exactly as long as ``names``;
          2. a list of in-range integer indices;
          3. a list of output names (case-insensitive).
        Because the mask form wins, e.g. ``[0]`` with a single output name
        selects nothing. A single string is treated as a one-name list.
        Non-empty selectors matching none of the forms yield ``[]`` and log an error.
        """
        if isinstance(selector, str):
            # A bare string would otherwise be iterated character by character.
            selector = [selector]

        n = len(names)
        lower_names = {nm.lower() for nm in names}
        name_map = {nm.lower(): i for i, nm in enumerate(names)}

        is_bool_mask = (n > 0 and len(selector) == n) and all(
            isinstance(x, bool) or (isinstance(x, int) and x in (0, 1)) for x in selector
        )
        if is_bool_mask:
            return sorted({i for i, flag in enumerate(selector) if bool(flag)})

        is_indices = len(selector) <= n and all(isinstance(x, int) and 0 <= x < n for x in selector)
        if is_indices:
            return sorted(set(selector))

        is_names = len(selector) <= n and all(isinstance(x, str) and x.lower() in lower_names for x in selector)
        idxs: List[int] = []
        if is_names:
            for s in selector:
                j = name_map.get(s.lower())
                if j is not None:
                    idxs.append(j)

        if not idxs:
            _LOG.error(f"SaveMRIImages Callback - Invalid output selector after filtering: {selector}")
        return sorted(set(idxs))

    def _resolve_infer_fn(self, pl_module: pl.LightningModule) -> Optional[Callable]:
        """Return the first callable candidate method on ``pl_module``.

        Falls back to ``pl_module.forward`` when enabled. Returns ``None``
        (and logs a warning) when nothing usable is found.
        """
        for name in self.infer_method_candidates:
            fn = getattr(pl_module, name, None)
            if callable(fn):
                _LOG.info(f"SaveMRIImages Callback - Using inference method '{name}'.")
                return fn
        if self.fallback_to_forward and hasattr(pl_module, "forward"):
            _LOG.warning(
                "SaveMRIImages Callback - None of %s found; falling back to forward().",
                self.infer_method_candidates,
            )
            return pl_module.forward
        _LOG.warning("SaveMRIImages Callback - No inference method and fallback disabled; will skip saving.")
        return None

    def _slice_like_batch(self, val: Any, batch_size: int, i: int):
        """Return sample ``i`` of a batched value; other values are returned as-is.

        Tensors whose first dim equals ``batch_size`` are sliced as ``val[i:i+1]``
        (batch dim kept). Lists/tuples of length ``batch_size`` are indexed as ``val[i]``.
        """
        if torch.is_tensor(val):
            if val.dim() > 0 and val.size(0) == batch_size:
                return val[i:i + 1]
            return val
        if isinstance(val, (list, tuple)) and len(val) == batch_size:
            return val[i]
        return val

    def _build_infer_kwargs(
        self,
        pl_module: pl.LightningModule,
        batch: Dict[str, Any],
        batch_size: int,
        i: int,
    ) -> Dict[str, Any]:
        """Collect sample ``i`` for every configured key; the first key is the input.

        Tensors are moved to ``pl_module.device``; ``None`` values are dropped.

        Raises:
            KeyError: if a configured key is absent from ``batch``.
        """
        if not self.infer_kwarg_keys:
            return {}
        kwargs: Dict[str, Any] = {}
        for k in self.infer_kwarg_keys:
            if k not in batch:
                raise KeyError(f"Key:{k} missing in batch, Available keys:{list(batch.keys())}")
            v = self._slice_like_batch(batch[k], batch_size, i)
            if torch.is_tensor(v):
                v = v.to(pl_module.device, non_blocking=True)
            kwargs[k] = v
        return {kk: vv for kk, vv in kwargs.items() if vv is not None}

    def _get_val_dataloader(self, trainer: pl.Trainer) -> Optional[tuple]:
        """Return ``(datamodule, val_dataloader)``, or ``None`` when no datamodule exists.

        A custom ``dataset_module`` has its ``setup()`` called first. Setup
        failures are logged rather than raised.
        """
        if self.dataset_module is not None:
            dm = self.dataset_module
            try:
                if hasattr(dm, "setup"):
                    dm.setup()
            except Exception as e:
                _LOG.warning(f"SaveMRIImages Callback - Failed to prepare/setup custom dataset module: {e}")
        else:
            dm = trainer.datamodule
            if dm is None:
                _LOG.warning("SaveMRIImages Callback - No datamodule; skipping.")
                return None
        return dm, dm.val_dataloader()

    def _infer_sample(
        self,
        trainer: pl.Trainer,
        infer_fn: Callable,
        infer_kwargs: Dict[str, Any],
        mod_i: Optional[str],
    ) -> List[Any]:
        """Run inference for one sample and return outputs as a positional list.

        The kwargs dict is passed as a single positional argument (not unpacked).
        Dict results are re-ordered by ``output_names``; tuple/list results are
        used as-is; anything else becomes a one-element list.

        Raises:
            RuntimeError: wrapping any exception raised by the inference call.
        """
        # Reuse the trainer's precision (autocast) context when available.
        plugin = getattr(trainer.strategy, "precision_plugin", None)
        ctx = plugin.forward_context() if plugin else contextlib.nullcontext()
        try:
            with ctx, torch.no_grad():
                result = infer_fn(infer_kwargs)
        except Exception as e:
            kw_types = {k: type(v).__name__ for k, v in infer_kwargs.items()}
            raise RuntimeError(
                f"SaveMRIImages: inference failed for modality '{mod_i}'. "
                f"Kwarg types: {kw_types}"
            ) from e

        if isinstance(result, dict):
            return [result.get(nm, None) for nm in self.output_names]
        if isinstance(result, (tuple, list)):
            return list(result)
        return [result]  # single tensor

    def _save_inputs(self, infer_kwargs: Dict[str, Any], epoch_dir: Path, mod_label: str) -> Any:
        """Save the configured input volumes of one sample.

        Returns the ``filename_or_obj`` of the first saved input (or the
        placeholder filename) so outputs can be saved under the same stem.

        Raises:
            ValueError: if an input's leading (batch) dimension is not 1.
        """
        input_keys = self.input_save_list or [self.infer_kwarg_keys[0]]  # first infer key is the input
        subject: Any = None
        for inp_name in input_keys:
            inp = infer_kwargs[inp_name]
            if inp.shape[0] != 1:
                raise ValueError(f"SaveMRIImages - Input dimension error, Length of 1st Dim should be 1. Got:{inp.shape} ")
            inp_vol = inp[0].detach().to(dtype=self.save_dtype).clamp(self.clamp_min, self.clamp_max).cpu()
            if not isinstance(inp_vol, MetaTensor):
                inp_vol = MetaTensor(inp_vol, meta={"filename_or_obj": self._PLACEHOLDER_FILENAME})
            if subject is None:
                src = inp_vol.meta.get("filename_or_obj")
                if isinstance(src, (str, Path)) and str(src):
                    subject = src
            input_saver = SaveImage(
                output_dir=str(epoch_dir),
                output_postfix=f"{self.tag}_{mod_label}_input_{inp_name}",
                output_ext=".nii.gz",
                separate_folder=self.separate_folder,
            )
            input_saver(inp_vol)

            del input_saver, inp_vol
            gc.collect()
        return subject if subject is not None else self._PLACEHOLDER_FILENAME

    def _save_outputs(self, outputs: List[Any], epoch_dir: Path, mod_label: str, subject: Any) -> None:
        """Save the selected model outputs of one sample.

        Missing/``None`` outputs are skipped. Outputs take the optional
        activation, then dtype cast and clamping.
        """
        for j in self.save_indices:
            if j >= len(outputs) or outputs[j] is None:
                continue
            out_vol = outputs[j][0].detach()
            if self.output_activation == "sigmoid":
                out_vol = torch.sigmoid(out_vol)
            elif self.output_activation == "tanh":
                out_vol = torch.tanh(out_vol)
            out_vol = out_vol.to(self.save_dtype).clamp(self.clamp_min, self.clamp_max).cpu()

            if not isinstance(out_vol, MetaTensor):
                # SaveImage needs "filename_or_obj" to name the file; a bare
                # MetaTensor(meta={}) still gains default keys and would fail.
                # TODO: also propagate the input's affine when spatial shapes match.
                out_vol = MetaTensor(out_vol, meta={"filename_or_obj": subject})

            out_name = self.output_names[j] if j < len(self.output_names) else f"out{j}"
            saver = SaveImage(
                output_dir=str(epoch_dir),
                output_postfix=f"{self.tag}_{mod_label}_output_{out_name}",
                output_ext=".nii.gz",
                separate_folder=self.separate_folder,
            )
            saver(out_vol)

            del saver, out_vol
            gc.collect()

    @torch.no_grad()
    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Save one input/output set per requested modality for the current epoch.

        Raises:
            KeyError: if a batch lacks any of ``infer_kwarg_keys``.
            RuntimeError: if inference fails for a sample.
            ValueError: if a saved input does not have batch size 1.
        """
        # Only the main/global process writes
        if not trainer.is_global_zero:
            return

        infer_fn = self._resolve_infer_fn(pl_module)
        if infer_fn is None:
            return

        loaded = self._get_val_dataloader(trainer)
        if loaded is None:
            return
        dm, vloader = loaded

        run_dir = Path(self.dirpath) if self.dirpath else Path(trainer.log_dir or trainer.default_root_dir)
        epoch_dir = run_dir / self.subdir / f"epoch-{int(pl_module.current_epoch):03d}"
        epoch_dir.mkdir(parents=True, exist_ok=True)

        want = set(self.modalities)
        got: set[str] = set()
        _LOG.info(
            f"SaveMRIImages Callback - Collecting modalities {sorted(want)} for epoch {int(pl_module.current_epoch)}"
        )

        required = set(self.infer_kwarg_keys)

        for batch in vloader:
            if want == got:
                break

            missing = required - set(batch.keys())
            if missing:
                raise KeyError(
                    f"SaveMRIImages: missing required keys: {', '.join(sorted(missing))}. Present: {', '.join(sorted(batch.keys()))}"
                )

            mods = batch.get("modality", None)
            batch_size = int(batch[self.infer_kwarg_keys[0]].shape[0])  # Shape of First kwarg
            if mods is None:
                _LOG.warning(
                    "SaveMRIImages Callback - Batch missing 'modality'; saving first available samples."
                )
                mod_list = [None] * batch_size
            else:
                mod_list = [m.lower() for m in mods]

            # Without modality labels, save as many samples as modalities are still missing.
            remaining = len(want - got)
            saved_without_modality = 0
            for i in range(batch_size):
                if want == got:
                    break

                mod_i = mod_list[i]
                if mod_i is not None:
                    if (mod_i not in want) or (mod_i in got):
                        continue
                    mod_label = mod_i
                else:
                    if saved_without_modality >= remaining:
                        break
                    mod_label = f"sample{i:03d}"

                infer_kwargs = self._build_infer_kwargs(
                    pl_module=pl_module,
                    batch=batch,
                    batch_size=batch_size,
                    i=i,
                )
                outputs = self._infer_sample(trainer, infer_fn, infer_kwargs, mod_i)
                subject = self._save_inputs(infer_kwargs, epoch_dir, mod_label)
                self._save_outputs(outputs, epoch_dir, mod_label, subject)

                if mod_i is None:
                    saved_without_modality += 1
                    if saved_without_modality >= remaining:
                        got = set(want)
                else:
                    got.add(mod_i)
                _LOG.info(
                    f"SaveMRIImages Callback - Saved {mod_label} (input & {','.join(self.output_names[k] for k in self.save_indices if k < len(self.output_names))}) → {epoch_dir}"
                )

                del infer_kwargs, outputs
                gc.collect()

        missing = sorted(want - got)
        if missing:
            _LOG.warning(f"SaveMRIImages Callback - Missing modalities this epoch: {missing}")
        else:
            _LOG.info(f"SaveMRIImages Callback - Saved all requested modalities: {sorted(got)}")

        # After saving, if we used a custom dataset, drop our references and free memory
        if self.dataset_module is not None:
            del vloader, dm
            gc.collect()
            torch.cuda.empty_cache()
brainmint/callbacks/modality_completion_schedule.py ADDED
@@ -0,0 +1,315 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence
5
+
6
+ import pytorch_lightning as pl
7
+
8
+ from brainmint.utils.schedules import PiecewiseSchedule
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def _normalize_probs(probs: Mapping[str, float]) -> Dict[str, float]:
14
+ out = {str(k): float(v) for k, v in dict(probs).items()}
15
+ s = float(sum(max(0.0, v) for v in out.values()))
16
+ if s <= 0.0:
17
+ raise ValueError(f"Cannot normalize probabilities with non-positive sum: {out}")
18
+ return {k: max(0.0, v) / s for k, v in out.items()}
19
+
20
+
21
+ def _iter_transforms(root: Any) -> Iterable[Any]:
22
+ """Depth-first walk over transform graphs (matches tests/data/test_translation_configs.py)."""
23
+ if root is None:
24
+ return
25
+ stack = [root]
26
+ seen = set()
27
+ while stack:
28
+ current = stack.pop()
29
+ if id(current) in seen:
30
+ continue
31
+ seen.add(id(current))
32
+ yield current
33
+
34
+ children: List[Any] = []
35
+ nested = getattr(current, "_transform", None)
36
+ if nested is not None and nested is not current:
37
+ children.append(nested)
38
+
39
+ transforms = getattr(current, "transforms", None)
40
+ if transforms:
41
+ children.extend(list(transforms))
42
+
43
+ extra_start = getattr(current, "extra_xforms_start", None)
44
+ if extra_start:
45
+ children.extend(list(extra_start))
46
+
47
+ extra_end = getattr(current, "extra_xforms_end", None)
48
+ if extra_end:
49
+ children.extend(list(extra_end))
50
+
51
+ stack.extend(reversed(children))
52
+
53
+
54
def _find_choice_states(dm: Any) -> List[Any]:
    """Collect ``state`` objects exposing ``get_choices``/``set_choices``.

    The walk starts at ``dm.train_tf.transform`` when that is truthy, else at
    ``dm.train_tf``. Results are returned in traversal order.
    """
    train_tf = getattr(dm, "train_tf", None)
    start = getattr(train_tf, "transform", None) or train_tf
    found = (getattr(node, "state", None) for node in _iter_transforms(start))
    return [
        candidate
        for candidate in found
        if candidate is not None
        and hasattr(candidate, "get_choices")
        and hasattr(candidate, "set_choices")
    ]
63
+
64
+
65
def _find_sampling_states(dm: Any) -> List[Any]:
    """Collect ``sampling_state`` objects exposing ``get_config``/``set_config``.

    The walk starts at ``dm.train_tf.transform`` when that is truthy, else at
    ``dm.train_tf``. Results are returned in traversal order.
    """
    train_tf = getattr(dm, "train_tf", None)
    start = getattr(train_tf, "transform", None) or train_tf
    found = (getattr(node, "sampling_state", None) for node in _iter_transforms(start))
    return [
        candidate
        for candidate in found
        if candidate is not None
        and hasattr(candidate, "get_config")
        and hasattr(candidate, "set_config")
    ]
74
+
75
+
76
class ModalityCompletionScheduleCallback(pl.Callback):
    """
    Generic, config-driven epoch-boundary scheduler for modality completion experiments.

    What it can update:
      1) Bucket sampler probabilities via DataModule.set_bucket_probs(...)
      2) Stream selection probabilities via SharedChoiceState.set_choices(...)
      3) Partial sampling config via SharedSamplingState.set_config(...)

    All behavior is driven by YAML config (step/linear schedules). No stage hardcoding.

    Args:
        bucket_schedule: ``PiecewiseSchedule`` spec whose value at an epoch is a
            mapping ``bucket -> probability`` passed to ``datamodule.set_bucket_probs``.
        choice_schedule: ``PiecewiseSchedule`` spec whose value is a nested mapping
            ``bucket -> modality -> alias -> probability``, merged into the ``probs``
            of every discovered choice state.
        sampling_schedule: ``PiecewiseSchedule`` spec whose value is a mapping
            merged into the config of every discovered sampling state.
        update_on: ``"epoch_start"`` or ``"epoch_end"`` — which train-epoch hook applies updates.
        strict: raise on missing targets / malformed values instead of warning or skipping.
        normalize_bucket_probs: rescale bucket probabilities to sum to one.
        normalize_choice_probs: rescale per-modality alias probabilities to sum to one.
        auto_complement_two_way: when a modality has exactly two stream aliases and
            the schedule sets only one, set the other to ``1 - p``.
        log_updates: log every applied update at INFO level.
    """

    def __init__(
        self,
        *,
        bucket_schedule: Optional[Sequence[Mapping[str, Any]]] = None,
        choice_schedule: Optional[Sequence[Mapping[str, Any]]] = None,
        sampling_schedule: Optional[Sequence[Mapping[str, Any]]] = None,
        update_on: str = "epoch_start",
        strict: bool = True,
        normalize_bucket_probs: bool = True,
        normalize_choice_probs: bool = True,
        auto_complement_two_way: bool = True,
        log_updates: bool = True,
    ) -> None:
        super().__init__()
        self.update_on = str(update_on).lower().strip()
        if self.update_on not in ("epoch_start", "epoch_end"):
            raise ValueError("update_on must be one of: 'epoch_start', 'epoch_end'")

        self.strict = bool(strict)
        self.normalize_bucket_probs = bool(normalize_bucket_probs)
        self.normalize_choice_probs = bool(normalize_choice_probs)
        self.auto_complement_two_way = bool(auto_complement_two_way)
        self.log_updates = bool(log_updates)

        self.bucket_sched = PiecewiseSchedule(bucket_schedule, name="bucket_schedule")
        self.choice_sched = PiecewiseSchedule(choice_schedule, name="choice_schedule")
        self.sampling_sched = PiecewiseSchedule(sampling_schedule, name="sampling_schedule")

        # Populated in setup() from the datamodule's train transforms.
        self._choice_states: List[Any] = []
        self._sampling_states: List[Any] = []

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:
        """Discover shared choice/sampling states in the datamodule's train transforms.

        Only runs for ``stage`` ``None`` or ``"fit"``. In strict mode, raises
        ``RuntimeError`` when a non-empty schedule has no state to update.
        """
        if stage not in (None, "fit"):
            return
        dm = trainer.datamodule
        self._choice_states = _find_choice_states(dm)
        self._sampling_states = _find_sampling_states(dm)

        if self.strict:
            if (self.choice_sched.steps or self.choice_sched.lines) and not self._choice_states:
                raise RuntimeError("choice_schedule provided but no SharedChoiceState found in train transforms.")
            if (self.sampling_sched.steps or self.sampling_sched.lines) and not self._sampling_states:
                raise RuntimeError("sampling_schedule provided but no SharedSamplingState found in train transforms.")

    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Apply the schedules for ``trainer.current_epoch`` before the epoch trains."""
        if self.update_on == "epoch_start":
            self._apply(trainer)

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Apply the schedules evaluated at ``trainer.current_epoch`` after the epoch.

        NOTE(review): the value scheduled for the epoch that just finished is
        applied here, so it presumably first takes effect in the following
        epoch — confirm this one-epoch lag is intended.
        """
        if self.update_on == "epoch_end":
            self._apply(trainer)

    def _apply(self, trainer: pl.Trainer) -> None:
        """Evaluate every schedule at the current epoch and push non-``None`` values."""
        epoch = int(trainer.current_epoch)
        dm = trainer.datamodule

        bucket_probs = self.bucket_sched.value_at(epoch)
        choice_updates = self.choice_sched.value_at(epoch)
        sampling_updates = self.sampling_sched.value_at(epoch)

        if bucket_probs is not None:
            self._apply_bucket_probs(dm, bucket_probs, epoch)

        if choice_updates is not None:
            self._apply_choice_probs(choice_updates, epoch)

        if sampling_updates is not None:
            self._apply_sampling_config(sampling_updates, epoch)

    # ---------------------------
    # Bucket probs
    # ---------------------------
    def _apply_bucket_probs(self, dm: Any, value: Any, epoch: int) -> None:
        """Push ``bucket -> prob`` to ``dm.set_bucket_probs`` (normalized if configured).

        Raises:
            TypeError: if ``value`` is not a mapping.
            RuntimeError: in strict mode, if the datamodule lacks ``set_bucket_probs``.
        """
        if not isinstance(value, Mapping):
            raise TypeError(f"bucket_schedule value must be a mapping bucket->prob, got {type(value)}")

        probs = {str(k): float(v) for k, v in dict(value).items()}
        if self.normalize_bucket_probs:
            probs = _normalize_probs(probs)

        setter = getattr(dm, "set_bucket_probs", None)
        if not callable(setter):
            msg = "DataModule has no set_bucket_probs(); cannot apply bucket_schedule."
            if self.strict:
                raise RuntimeError(msg)
            logger.warning(msg)
            return

        setter(probs)

        if self.log_updates:
            logger.info("[Schedule] bucket_probs(epoch=%d): %s", int(epoch), probs)

    # ---------------------------
    # Choice probs (SharedChoiceState)
    # ---------------------------
    def _apply_choice_probs(self, value: Any, epoch: int) -> None:
        """Merge ``bucket -> modality -> alias -> prob`` into every choice state.

        A modality without its own config falls back to the bucket's ``"*"``
        entry, and the update is written back to ``"*"`` — which presumably
        affects every modality that relies on it. Updated probs are restricted
        to the aliases listed under ``streams``, optionally complemented for
        two-way choices, and optionally normalized.

        Raises (strict mode):
            KeyError / TypeError for unknown buckets, aliases, or malformed configs.
        """
        if not isinstance(value, Mapping):
            raise TypeError(f"choice_schedule value must be a nested mapping, got {type(value)}")

        updates: Dict[str, Any] = dict(value)

        for state in self._choice_states:
            cur = dict(state.get_choices())
            changed = False

            for bucket, mods in updates.items():
                bucket = str(bucket)
                if bucket not in cur:
                    if self.strict:
                        raise KeyError(f"choice_schedule refers to unknown bucket '{bucket}'")
                    logger.warning("choice_schedule: unknown bucket '%s' (ignored)", bucket)
                    continue
                if not isinstance(mods, Mapping):
                    if self.strict:
                        raise TypeError(f"choice_schedule[{bucket}] must be a mapping modality->alias_probs")
                    continue
                if not isinstance(cur[bucket], Mapping):
                    if self.strict:
                        raise TypeError(f"choices[{bucket}] must be mapping")
                    continue

                for mod, alias_probs in mods.items():
                    mod_key = str(mod).lower()
                    if not isinstance(alias_probs, Mapping):
                        if self.strict:
                            raise TypeError(f"choice_schedule[{bucket}][{mod_key}] must be mapping alias->prob")
                        continue

                    # Re-read the bucket for every modality so earlier updates in the
                    # same bucket (e.g. to "*") are composed, not overwritten by a
                    # stale snapshot.
                    bcfg = cur[bucket]

                    # Find modality config; fall back to "*"
                    mcfg = bcfg.get(mod_key, None)
                    wildcard = False
                    if mcfg is None:
                        mcfg = bcfg.get("*", None)
                        wildcard = True
                    if mcfg is None:
                        if self.strict:
                            raise KeyError(f"choices missing config for bucket='{bucket}' modality='{mod_key}' (and no '*')")
                        logger.warning("choices missing config for bucket='%s' modality='%s' (ignored)", bucket, mod_key)
                        continue
                    if not isinstance(mcfg, Mapping):
                        if self.strict:
                            raise TypeError(f"choices[{bucket}][{mod_key if not wildcard else '*'}] must be mapping")
                        continue

                    streams = dict(mcfg.get("streams") or {})
                    if not streams:
                        if self.strict:
                            raise KeyError(f"choices[{bucket}][{mod_key if not wildcard else '*'}].streams is empty")
                        continue

                    probs = dict(mcfg.get("probs") or {})

                    # apply updates (set)
                    for alias, p in alias_probs.items():
                        probs[str(alias)] = float(p)

                    if self.strict:
                        unknown_aliases = {str(a) for a in alias_probs.keys()} - set(streams.keys())
                        if unknown_aliases:
                            raise KeyError(
                                f"choice_schedule[{bucket}][{mod_key}] refers to unknown stream aliases: "
                                f"{sorted(unknown_aliases)}"
                            )

                    # Optional complement for 2-way choices when only one alias was specified
                    if self.auto_complement_two_way:
                        defined_aliases = list(streams.keys())
                        specified = [str(a) for a in alias_probs.keys() if str(a) in defined_aliases]
                        if len(defined_aliases) == 2 and len(specified) == 1:
                            a0 = specified[0]
                            other = defined_aliases[0] if defined_aliases[1] == a0 else defined_aliases[1]
                            p0 = float(probs.get(a0, 0.0))
                            probs[other] = max(0.0, 1.0 - p0)

                    # Keep only aliases defined in streams
                    probs = {a: float(probs.get(a, 0.0)) for a in streams.keys()}

                    if self.normalize_choice_probs:
                        probs = _normalize_probs(probs)

                    new_mcfg = dict(mcfg)
                    new_mcfg["probs"] = probs

                    new_bucket_cfg = dict(bcfg)
                    if wildcard and "*" in new_bucket_cfg:
                        new_bucket_cfg["*"] = new_mcfg
                    else:
                        new_bucket_cfg[mod_key] = new_mcfg
                    cur[bucket] = new_bucket_cfg

                    changed = True

            if changed:
                # Epoch bookkeeping is optional and best-effort: call it only when
                # the state provides it, and never let it block the update.
                set_epoch = getattr(state, "set_epoch", None)
                if callable(set_epoch):
                    try:
                        set_epoch(int(epoch))
                    except Exception:
                        logger.debug(
                            "set_epoch(%d) failed on %s (ignored)", int(epoch), type(state).__name__, exc_info=True
                        )
                state.set_choices(cur)
                if self.log_updates:
                    logger.info("[Schedule] choice_probs(epoch=%d): updated SharedChoiceState", epoch)

    # ---------------------------
    # Sampling config (SharedSamplingState)
    # ---------------------------
    def _apply_sampling_config(self, value: Any, epoch: Optional[int] = None) -> None:
        """Merge ``value`` into the config of every discovered sampling state.

        In strict mode, ``sigma_prob`` is coerced to float and must lie in
        [0, 1], and ``sigma_alpha`` is coerced to float and must be >= 0.

        Raises:
            TypeError: if ``value`` is not a mapping.
            ValueError: in strict mode, for out-of-range values.
        """
        if not isinstance(value, Mapping):
            raise TypeError(f"sampling_schedule value must be a mapping, got {type(value)}")
        upd = {str(k): v for k, v in dict(value).items()}
        if self.strict:
            if "sigma_prob" in upd:
                upd["sigma_prob"] = float(upd["sigma_prob"])
                if not (0.0 <= upd["sigma_prob"] <= 1.0):
                    raise ValueError("sampling_schedule sigma_prob must be between 0 and 1.")
            if "sigma_alpha" in upd:
                upd["sigma_alpha"] = float(upd["sigma_alpha"])
                if upd["sigma_alpha"] < 0.0:
                    raise ValueError("sampling_schedule sigma_alpha must be >= 0.")

        for ss in self._sampling_states:
            cur = dict(ss.get_config())
            cur.update(upd)
            ss.set_config(cur)

        if self.log_updates:
            if epoch is None:
                logger.info("[Schedule] sampling_config: %s", upd)
            else:
                logger.info("[Schedule] sampling_config(epoch=%d): %s", int(epoch), upd)