brainmint 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainmint/__init__.py +12 -0
- brainmint/_version.py +24 -0
- brainmint/callbacks/log_images.py +341 -0
- brainmint/callbacks/modality_completion_schedule.py +315 -0
- brainmint/data/__init__.py +1 -0
- brainmint/data/brainscape.py +401 -0
- brainmint/data/brainscape_paired.py +945 -0
- brainmint/data/transforms/__init__.py +1 -0
- brainmint/data/transforms/conditioning.py +52 -0
- brainmint/data/transforms/copy_paths.py +44 -0
- brainmint/data/transforms/demographics.py +378 -0
- brainmint/data/transforms/intensity.py +50 -0
- brainmint/data/transforms/modality_choice.py +287 -0
- brainmint/data/transforms/mri_vae.py +308 -0
- brainmint/data/transforms/segmentation.py +478 -0
- brainmint/data/transforms/stream_mapping.py +226 -0
- brainmint/data/transforms/synthetic_mask.py +580 -0
- brainmint/data/utils.py +35 -0
- brainmint/external/__init__.py +20 -0
- brainmint/external/registry.py +78 -0
- brainmint/external/repo_manager.py +333 -0
- brainmint/external/sys_path.py +32 -0
- brainmint/inference/__init__.py +1 -0
- brainmint/inference/controlnet/__init__.py +5 -0
- brainmint/inference/controlnet/controlnet_inference.py +548 -0
- brainmint/inference/core/__init__.py +23 -0
- brainmint/inference/core/context.py +76 -0
- brainmint/inference/core/interfaces.py +70 -0
- brainmint/inference/core/runner.py +93 -0
- brainmint/inference/core/scheduler.py +98 -0
- brainmint/inference/diffusion/__init__.py +2 -0
- brainmint/inference/diffusion/conditioning/__init__.py +2 -0
- brainmint/inference/diffusion/conditioning/base.py +67 -0
- brainmint/inference/diffusion/conditioning/common/__init__.py +1 -0
- brainmint/inference/diffusion/conditioning/common/class_labels.py +83 -0
- brainmint/inference/diffusion/conditioning/common/constant.py +155 -0
- brainmint/inference/diffusion/conditioning/common/demographics.py +85 -0
- brainmint/inference/diffusion/conditioning/common/random_vec.py +35 -0
- brainmint/inference/diffusion/conditioning/common/vector_ops.py +33 -0
- brainmint/inference/diffusion/conditioning/ldm/__init__.py +1 -0
- brainmint/inference/diffusion/conditioning/ldm/fixed_4vec.py +41 -0
- brainmint/inference/diffusion/latent/__init__.py +2 -0
- brainmint/inference/diffusion/latent/base.py +7 -0
- brainmint/inference/diffusion/latent/common/__init__.py +1 -0
- brainmint/inference/diffusion/latent/common/fixed_shape.py +54 -0
- brainmint/inference/diffusion/latent/common/from_autoencoder_encode.py +43 -0
- brainmint/inference/diffusion/latent/common/from_batch.py +30 -0
- brainmint/inference/diffusion/latent/common/from_image_shape.py +52 -0
- brainmint/inference/diffusion/pipelines/__init__.py +1 -0
- brainmint/inference/diffusion/pipelines/base.py +121 -0
- brainmint/inference/diffusion/pipelines/generation/__init__.py +1 -0
- brainmint/inference/diffusion/pipelines/generation/ldm/__init__.py +1 -0
- brainmint/inference/diffusion/pipelines/generation/ldm/ukb_ddim.py +14 -0
- brainmint/inference/diffusion/samplers/__init__.py +2 -0
- brainmint/inference/diffusion/samplers/base.py +85 -0
- brainmint/inference/diffusion/samplers/common/cond_unet.py +69 -0
- brainmint/inference/diffusion/samplers/common/monai_unet.py +240 -0
- brainmint/inference/diffusion/samplers/ldm/__init__.py +1 -0
- brainmint/inference/diffusion/samplers/ldm/ukb_ddim.py +50 -0
- brainmint/inference/dynamic_inference.py +33 -0
- brainmint/inference/generation/__init__.py +1 -0
- brainmint/inference/generation/batch_builders/__init__.py +1 -0
- brainmint/inference/generation/batch_builders/med_ddpm.py +50 -0
- brainmint/inference/generation/pipelines/__init__.py +1 -0
- brainmint/inference/generation/pipelines/common.py +6 -0
- brainmint/inference/generation/pipelines/external.py +13 -0
- brainmint/inference/generation/pipelines/hagan.py +43 -0
- brainmint/inference/generation/pipelines/maisi.py +52 -0
- brainmint/inference/generation/pipelines/med_ddpm.py +89 -0
- brainmint/inference/generation/pipelines/wdm3d.py +58 -0
- brainmint/inference/io/__init__.py +5 -0
- brainmint/inference/io/base.py +33 -0
- brainmint/inference/io/dataset_writers.py +125 -0
- brainmint/inference/io/readers.py +56 -0
- brainmint/inference/io/writers.py +191 -0
- brainmint/inference/postprocess/__init__.py +3 -0
- brainmint/inference/postprocess/base.py +13 -0
- brainmint/inference/postprocess/brats_pipeline.py +317 -0
- brainmint/inference/postprocess/reorient.py +101 -0
- brainmint/inference/translation/__init__.py +17 -0
- brainmint/inference/translation/generators/__init__.py +12 -0
- brainmint/inference/translation/generators/aldm.py +7 -0
- brainmint/inference/translation/generators/cwdm.py +7 -0
- brainmint/inference/translation/pipelines/__init__.py +1 -0
- brainmint/integrations/__init__.py +1 -0
- brainmint/integrations/aldm/__init__.py +10 -0
- brainmint/integrations/aldm/ldm.py +47 -0
- brainmint/integrations/aldm/repo.py +132 -0
- brainmint/integrations/aldm/vqgan.py +136 -0
- brainmint/integrations/brainsynth/__init__.py +7 -0
- brainmint/integrations/brainsynth/inferer.py +91 -0
- brainmint/integrations/brainsynth/vendor_vqvae.py +569 -0
- brainmint/integrations/brasyn/__init__.py +7 -0
- brainmint/integrations/brasyn/io.py +154 -0
- brainmint/integrations/brasyn/missing_mri.py +181 -0
- brainmint/integrations/brasyn/modalities.py +71 -0
- brainmint/integrations/brasyn/runtime.py +347 -0
- brainmint/integrations/cwdm/__init__.py +8 -0
- brainmint/integrations/cwdm/repo.py +31 -0
- brainmint/integrations/cwdm/translator.py +103 -0
- brainmint/integrations/hagan/__init__.py +8 -0
- brainmint/integrations/hagan/generator.py +35 -0
- brainmint/integrations/hagan/repo.py +31 -0
- brainmint/integrations/maisi/__init__.py +8 -0
- brainmint/integrations/maisi/autoencoder.py +15 -0
- brainmint/integrations/maisi/generator.py +301 -0
- brainmint/integrations/maisi/repo.py +31 -0
- brainmint/integrations/med_ddpm/__init__.py +8 -0
- brainmint/integrations/med_ddpm/generator.py +68 -0
- brainmint/integrations/med_ddpm/repo.py +31 -0
- brainmint/integrations/wdm3d/__init__.py +8 -0
- brainmint/integrations/wdm3d/generator.py +148 -0
- brainmint/integrations/wdm3d/repo.py +31 -0
- brainmint/lightning/__init__.py +7 -0
- brainmint/lightning/controlnet_module.py +1377 -0
- brainmint/lightning/diffusion_inference_module.py +110 -0
- brainmint/lightning/diffusion_module.py +1250 -0
- brainmint/lightning/export_latents_module.py +135 -0
- brainmint/lightning/generic_inference_module.py +78 -0
- brainmint/lightning/vae_module.py +324 -0
- brainmint/losses/__init__.py +1 -0
- brainmint/losses/mask_recon_loss.py +100 -0
- brainmint/losses/utils.py +23 -0
- brainmint/losses/vae_loss_manager.py +272 -0
- brainmint/metrics/__init__.py +1 -0
- brainmint/metrics/diffusion_metrics.py +552 -0
- brainmint/metrics/reconstruction.py +462 -0
- brainmint/models/__init__.py +1 -0
- brainmint/models/blocks/__init__.py +8 -0
- brainmint/models/blocks/haar_dwt.py +205 -0
- brainmint/models/blocks/haar_wavelet_fusion.py +187 -0
- brainmint/models/compression/__init__.py +8 -0
- brainmint/models/compression/aldm_vqgan.py +37 -0
- brainmint/models/compression/brainsynth_vqvae.py +59 -0
- brainmint/models/compression/dwt.py +52 -0
- brainmint/models/compression/ldm_vae.py +61 -0
- brainmint/models/compression/maisi_vae_gan.py +53 -0
- brainmint/models/compression/wavelet_fusion.py +714 -0
- brainmint/models/compression/wavelet_vae.py +500 -0
- brainmint/models/conditioning/demographics_encoder.py +241 -0
- brainmint/models/controlnet/__init__.py +5 -0
- brainmint/models/controlnet/controlnet.py +195 -0
- brainmint/models/controlnet/controlnet_monai.py +472 -0
- brainmint/models/generation/__init__.py +8 -0
- brainmint/models/generation/diffusion_unet.py +411 -0
- brainmint/models/generation/hagan.py +106 -0
- brainmint/models/generation/maisi.py +134 -0
- brainmint/models/generation/med_ddpm.py +125 -0
- brainmint/models/generation/wdm3d.py +166 -0
- brainmint/models/schedulers/__init__.py +1 -0
- brainmint/models/schedulers/rflow_scheduler.py +23 -0
- brainmint/models/translation/__init__.py +8 -0
- brainmint/models/translation/aldm.py +347 -0
- brainmint/models/translation/brasyn.py +45 -0
- brainmint/models/translation/cwdm.py +160 -0
- brainmint/models/translation/utils.py +106 -0
- brainmint/py.typed +0 -0
- brainmint/utils/__init__.py +1 -0
- brainmint/utils/batch.py +32 -0
- brainmint/utils/ema.py +329 -0
- brainmint/utils/gpumem_utils.py +50 -0
- brainmint/utils/schedules.py +199 -0
- brainmint/utils/spatial.py +71 -0
- brainmint/utils/state_dict_loader.py +371 -0
- brainmint/visualization/__init__.py +3 -0
- brainmint/visualization/slices.py +46 -0
- brainmint-0.1.0.dist-info/METADATA +109 -0
- brainmint-0.1.0.dist-info/RECORD +171 -0
- brainmint-0.1.0.dist-info/WHEEL +5 -0
- brainmint-0.1.0.dist-info/licenses/LICENSE +21 -0
- brainmint-0.1.0.dist-info/top_level.txt +1 -0
brainmint/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""BrainMint models, integrations, and inference tools for medical image synthesis."""
|
|
2
|
+
|
|
3
|
+
from importlib.metadata import PackageNotFoundError
|
|
4
|
+
from importlib.metadata import version as _metadata_version
|
|
5
|
+
|
|
6
|
+
# Resolve the package version. The generated ``_version`` module is preferred;
# installed distribution metadata is the first fallback and a fixed placeholder
# string is the last resort.
try:
    from ._version import __version__
except ImportError:  # also covers ModuleNotFoundError, which subclasses ImportError
    try:
        __version__ = _metadata_version("brainmint")
    except PackageNotFoundError:
        # Neither the generated module nor an installed distribution is available.
        __version__ = "0.0.0"
|
brainmint/_version.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# file generated by vcs-versioning
|
|
2
|
+
# don't change, don't track in version control
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
# Names exported by this generated module: each value is available under a
# dunder name and a plain alias.
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Type declarations for the names assigned below.
version: str
__version__: str
__version_tuple__: tuple[int | str, ...]
version_tuple: tuple[int | str, ...]
commit_id: str | None
__commit_id__: str | None

# Version string and its tuple form; both aliases are bound to the same object.
__version__ = version = '0.1.0'
__version_tuple__ = version_tuple = (0, 1, 0)

# No commit id is recorded in this file.
__commit_id__ = commit_id = None
|
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
import gc
|
|
2
|
+
import logging
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Dict, List, Optional, Sequence, Callable, Union
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import contextlib
|
|
8
|
+
import pytorch_lightning as pl
|
|
9
|
+
from monai.transforms import SaveImage
|
|
10
|
+
from monai.data import MetaTensor
|
|
11
|
+
|
|
12
|
+
_LOG = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_tb_writer(trainer):
    """Return the ``experiment`` of the first TensorBoard logger on *trainer*.

    ``trainer.loggers`` is used when it is present and truthy; otherwise a
    truthy ``trainer.logger`` is treated as a one-element list. Returns
    ``None`` when no ``pl.loggers.TensorBoardLogger`` is attached.
    """
    if getattr(trainer, "loggers", None):
        candidates = trainer.loggers
    elif getattr(trainer, "logger", None):
        candidates = [trainer.logger]
    else:
        candidates = []

    tb_logger = next(
        (cand for cand in candidates if isinstance(cand, pl.loggers.TensorBoardLogger)),
        None,
    )
    if tb_logger is None:
        return None
    return tb_logger.experiment  # SummaryWriter
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SaveMRIImages(pl.Callback):
|
|
24
|
+
"""
|
|
25
|
+
After each validation epoch:
|
|
26
|
+
• iterate the val dataloader,
|
|
27
|
+
• collect ONE sample per requested modality (e.g., T1w/T2w/T1ce/FLAIR),
|
|
28
|
+
• run module inference (preferred) or forward() as fallback,
|
|
29
|
+
• save paired NIfTI files into: <run_dir>/<subdir>/epoch-XXX/.
|
|
30
|
+
|
|
31
|
+
Filenames:
|
|
32
|
+
<basename>_<tag>_<MODALITY>_input_<INPUT_KEY>.nii.gz
|
|
33
|
+
<basename>_<tag>_<MODALITY>_output_<OUTNAME>.nii.gz (for each selected model output)
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
def __init__(
    self,
    modalities: Sequence[str] = ("T1w", "T2w", "T1ce", "FLAIR"),
    tag: str = "val",
    subdir: str = "val_samples",
    dirpath: Optional[str] = None,  # Defaults to trainer.log_dir
    infer_kwarg_keys: Sequence[str] = ("image",),  # Batch keys to pass to inference
    infer_method_candidates: Sequence[str] = ("_run_inference",),
    fallback_to_forward: bool = True,
    output_names: Sequence[str] = ("output", "z_mu", "z_sigma"),  # Names for outputs, matching the order the model returns.
    save_outputs: Sequence[Union[int, bool, str]] = (1, 0, 0),  # Output to save, Boolean Mask | Indices | Output Names
    input_save_list: Optional[Sequence[str]] = None,
    save_dtype: torch.dtype = torch.float32,
    clamp_min: float = 0.0,
    clamp_max: float = 1.0,
    output_activation: Optional[str] = None,
    separate_folder: bool = False,
    dataset_module: Optional[Any] = None,
) -> None:
    """Store the callback configuration.

    Args:
        modalities: Modality names to collect; stored lower-cased.
        tag: Prefix used in the saved-file postfix.
        subdir: Directory created under the run directory for the samples.
        dirpath: Explicit run directory; when falsy the trainer's log
            directory is used at save time.
        infer_kwarg_keys: Batch keys handed to the inference function; the
            first one is treated as the model input.
        infer_method_candidates: Method names looked up on the module, in order.
        fallback_to_forward: Whether ``forward`` may be used when no
            candidate method exists.
        output_names: Names of the model outputs, in return order.
        save_outputs: Which outputs to save (boolean mask, indices or names).
        input_save_list: Inference kwargs to save as inputs; when empty the
            first ``infer_kwarg_keys`` entry is saved.
        save_dtype: ``torch.dtype`` (or its name, e.g. ``"torch.float16"``)
            volumes are cast to before saving.
        clamp_min: Lower clamp bound applied before saving.
        clamp_max: Upper clamp bound applied before saving.
        output_activation: ``None``, ``"sigmoid"`` or ``"tanh"``
            (case-insensitive), applied to outputs before clamping.
        separate_folder: Forwarded to MONAI ``SaveImage``.
        dataset_module: Optional datamodule that replaces
            ``trainer.datamodule`` as the source of validation batches.

    Raises:
        ValueError: If ``output_activation`` is not a supported name.
    """
    super().__init__()

    self.modalities = [name.lower() for name in modalities]
    self.tag = tag
    self.subdir = subdir
    self.dirpath = dirpath
    self.infer_method_candidates = list(infer_method_candidates)
    self.fallback_to_forward = fallback_to_forward
    self.infer_kwarg_keys = list(infer_kwarg_keys)
    if input_save_list is None:
        self.input_save_list = []
    else:
        self.input_save_list = list(input_save_list)

    self.output_names = list(output_names)
    self.save_indices = self._normalize_save_selector(save_outputs, self.output_names)

    # Accept a dtype object or its textual name (with or without "torch.").
    if isinstance(save_dtype, torch.dtype):
        self.save_dtype = save_dtype
    else:
        dtype_name = str(save_dtype).strip().lower().removeprefix("torch.")
        self.save_dtype = getattr(torch, dtype_name)

    self.clamp_min = clamp_min
    self.clamp_max = clamp_max

    if isinstance(output_activation, str) and output_activation:
        self.output_activation = output_activation.lower()
    else:
        self.output_activation = None
    if self.output_activation not in (None, "sigmoid", "tanh"):
        raise ValueError(f"SaveMRIImages: unsupported output_activation={output_activation}")

    self.separate_folder = separate_folder
    self.dataset_module = dataset_module
|
|
80
|
+
|
|
81
|
+
@staticmethod
|
|
82
|
+
def _normalize_save_selector(selector: Sequence[Union[int, bool, str]], names: Sequence[str]) -> List[int]:
|
|
83
|
+
"""Normalize save_outputs → sorted unique indices."""
|
|
84
|
+
n = len(names)
|
|
85
|
+
lower_names = {nm.lower() for nm in names}
|
|
86
|
+
name_map = {nm.lower(): i for i, nm in enumerate(names)}
|
|
87
|
+
|
|
88
|
+
is_bool_mask = (n > 0 and len(selector) == n) and all(
|
|
89
|
+
isinstance(x, bool) or (isinstance(x, int) and x in (0, 1)) for x in selector
|
|
90
|
+
)
|
|
91
|
+
if is_bool_mask:
|
|
92
|
+
idxs = [i for i, flag in enumerate(selector) if bool(flag)]
|
|
93
|
+
return sorted(set(idxs))
|
|
94
|
+
|
|
95
|
+
is_indices = len(selector) <= n and all(isinstance(x, int) and 0 <= x < n for x in selector)
|
|
96
|
+
if is_indices:
|
|
97
|
+
return sorted(set(selector))
|
|
98
|
+
|
|
99
|
+
is_names = len(selector) <= n and all(isinstance(x, str) and x.lower() in lower_names for x in selector)
|
|
100
|
+
idxs: List[int] = []
|
|
101
|
+
if is_names:
|
|
102
|
+
for s in selector:
|
|
103
|
+
j = name_map.get(s.lower())
|
|
104
|
+
if j is not None:
|
|
105
|
+
idxs.append(j)
|
|
106
|
+
|
|
107
|
+
if not idxs:
|
|
108
|
+
_LOG.error(f"SaveMRIImages Callback - Invalid output selector after filtering: {selector}")
|
|
109
|
+
return sorted(set(idxs))
|
|
110
|
+
|
|
111
|
+
def _resolve_infer_fn(self, pl_module: pl.LightningModule) -> Optional[Callable]:
    """Pick the callable used to run inference on *pl_module*.

    The configured candidate method names are tried in order and the first
    callable attribute wins. If none is callable, ``pl_module.forward`` is
    returned when the forward fallback is enabled and the attribute exists;
    otherwise ``None`` is returned. Each outcome is logged.
    """
    lookups = (
        (name, getattr(pl_module, name, None)) for name in self.infer_method_candidates
    )
    for name, attr in lookups:
        if not callable(attr):
            continue
        _LOG.info(f"SaveMRIImages Callback - Using inference method '{name}'.")
        return attr

    can_use_forward = self.fallback_to_forward and hasattr(pl_module, "forward")
    if not can_use_forward:
        _LOG.warning("SaveMRIImages Callback - No inference method and fallback disabled; will skip saving.")
        return None

    _LOG.warning(
        "SaveMRIImages Callback - None of %s found; falling back to forward().",
        self.infer_method_candidates,
    )
    return pl_module.forward
|
|
125
|
+
|
|
126
|
+
def _slice_like_batch(self, val: Any, batch_size: int, i: int):
|
|
127
|
+
"""Return the i-th sample if batched along dim 0; else return as-is."""
|
|
128
|
+
if torch.is_tensor(val):
|
|
129
|
+
if val.dim() > 0 and val.size(0) == batch_size:
|
|
130
|
+
return val[i:i + 1]
|
|
131
|
+
return val
|
|
132
|
+
if isinstance(val, (list, tuple)) and len(val) == batch_size:
|
|
133
|
+
return val[i]
|
|
134
|
+
return val
|
|
135
|
+
|
|
136
|
+
def _build_infer_kwargs(
|
|
137
|
+
self,
|
|
138
|
+
pl_module: pl.LightningModule,
|
|
139
|
+
batch: Dict[str, Any],
|
|
140
|
+
batch_size: int,
|
|
141
|
+
i: int,
|
|
142
|
+
) -> Dict[str, Any]:
|
|
143
|
+
"""Build kwargs from configured keys; first entry is treated as input."""
|
|
144
|
+
if not self.infer_kwarg_keys:
|
|
145
|
+
return {}
|
|
146
|
+
kwargs: Dict[str, Any] = {}
|
|
147
|
+
for k in self.infer_kwarg_keys:
|
|
148
|
+
if k in batch:
|
|
149
|
+
v = self._slice_like_batch(batch[k], batch_size, i)
|
|
150
|
+
if torch.is_tensor(v):
|
|
151
|
+
v = v.to(pl_module.device, non_blocking=True)
|
|
152
|
+
kwargs[k] = v
|
|
153
|
+
else:
|
|
154
|
+
raise KeyError(f"Key:{k} missing in batch, Available keys:{list(batch.keys())}")
|
|
155
|
+
return {kk: vv for kk, vv in kwargs.items() if vv is not None}
|
|
156
|
+
|
|
157
|
+
@torch.no_grad()
def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
    """Save one input/output NIfTI set per requested modality after validation.

    Runs only on the global-zero process. Validation batches are drawn from
    ``self.dataset_module`` when one was configured, otherwise from
    ``trainer.datamodule``. The first sample seen for each wanted modality is
    pushed through the resolved inference function, and the selected inputs
    and outputs are written below
    ``<run_dir>/<subdir>/epoch-XXX/`` with MONAI ``SaveImage``.

    Batches without a ``"modality"`` entry are handled by saving the first
    available samples instead. Those samples are counted and labelled
    (``sampleNNN``) across batches, so a batch size smaller than the number
    of requested modalities neither overwrites earlier files nor forces a
    pass over the whole validation loader.

    Raises:
        KeyError: If a batch lacks one of ``self.infer_kwarg_keys``.
        RuntimeError: If the inference function raises for a sample.
        ValueError: If a single-sample input does not have a leading
            dimension of length 1.
    """
    # Only the main/global process writes
    if not trainer.is_global_zero:
        return

    run_dir = Path(self.dirpath) if self.dirpath else Path(trainer.log_dir or trainer.default_root_dir)
    epoch_dir = run_dir / self.subdir / f"epoch-{int(pl_module.current_epoch):03d}"
    epoch_dir.mkdir(parents=True, exist_ok=True)

    infer_fn = self._resolve_infer_fn(pl_module)
    if infer_fn is None:
        return

    # Choose which datamodule/dataloader to use
    using_custom_dm = self.dataset_module is not None
    if using_custom_dm:
        dm = self.dataset_module
        try:
            if hasattr(dm, "setup"):
                dm.setup()
        except Exception as e:
            # Best effort: a failing setup() is reported but does not abort.
            _LOG.warning(f"... prepare/setup custom dataset module: {e}")
    else:
        dm = trainer.datamodule
        if dm is None:
            _LOG.warning("SaveMRIImages Callback - No datamodule; skipping.")
            return
    vloader = dm.val_dataloader()

    want = set(self.modalities)
    got: set[str] = set()
    _LOG.info(
        f"SaveMRIImages Callback - Collecting modalities {sorted(want)} for epoch {int(pl_module.current_epoch)}"
    )

    required = set(self.infer_kwarg_keys)
    plugin = getattr(trainer.strategy, "precision_plugin", None)
    saved_names = ','.join(self.output_names[k] for k in self.save_indices if k < len(self.output_names))

    # Samples saved without modality information, counted across batches.
    saved_without_modality = 0

    for batch in vloader:
        if want == got:
            break

        missing = required - set(batch.keys())
        if missing:
            raise KeyError(
                f"SaveMRIImages: missing required keys: {', '.join(sorted(missing))}. Present: {', '.join(sorted(batch.keys()))}"
            )

        mods = batch.get("modality", None)
        batch_size = int(batch[self.infer_kwarg_keys[0]].shape[0])  # Shape of First kwarg
        if mods is None:
            _LOG.warning(
                "SaveMRIImages Callback - Batch missing 'modality'; saving first available samples."
            )
            mod_list = [None] * batch_size
        else:
            mod_list = [m.lower() for m in mods]

        for i in range(batch_size):
            if want == got:
                break

            mod_i = mod_list[i]
            if mod_i is not None:
                if (mod_i not in want) or (mod_i in got):
                    continue
                mod_label = mod_i
            else:
                if saved_without_modality >= len(want - got):
                    # Enough unlabelled samples were already written.
                    got = set(want)
                    break
                mod_label = f"sample{saved_without_modality:03d}"

            infer_kwargs = self._build_infer_kwargs(
                pl_module=pl_module,
                batch=batch,
                batch_size=batch_size,
                i=i,
            )

            # A fresh context per sample; gradients are already disabled by
            # the decorator on this method.
            ctx = plugin.forward_context() if plugin else contextlib.nullcontext()
            try:
                with ctx:
                    result = infer_fn(infer_kwargs)
            except Exception as e:
                kw_types = {k: type(v).__name__ for k, v in infer_kwargs.items()}
                raise RuntimeError(
                    f"SaveMRIImages: inference failed for modality '{mod_i}'. "
                    f"Kwarg types: {kw_types}"
                ) from e

            if isinstance(result, dict):
                outputs = [result.get(nm, None) for nm in self.output_names]
            elif isinstance(result, (tuple, list)):
                outputs = list(result)
            else:  # single tensor
                outputs = [result]

            # Prepare & save inputs; the first infer kwarg is the model
            # input unless an explicit list was configured.
            save_inputs = self.input_save_list or [self.infer_kwarg_keys[0]]
            for inp_name in save_inputs:
                inp = infer_kwargs[inp_name]
                if inp.shape[0] != 1:
                    raise ValueError(f"SaveMRIImages - Input dimension error, Length of 1st Dim should be 1. Got:{inp.shape} ")
                inp_vol = inp[0].detach().to(dtype=self.save_dtype).clamp(self.clamp_min, self.clamp_max).cpu()
                if not isinstance(inp_vol, MetaTensor):
                    inp_vol = MetaTensor(inp_vol, meta={"filename_or_obj": "_.nii.gz"})
                input_saver = SaveImage(
                    output_dir=str(epoch_dir),
                    output_postfix=f"{self.tag}_{mod_label}_input_{inp_name}",
                    output_ext=".nii.gz",
                    separate_folder=self.separate_folder,
                )
                input_saver(inp_vol)

                del input_saver, inp_vol
                gc.collect()

            for j in self.save_indices:
                if j >= len(outputs) or outputs[j] is None:
                    continue
                out_vol = outputs[j][0].detach()
                if self.output_activation == "sigmoid":
                    out_vol = torch.sigmoid(out_vol)
                elif self.output_activation == "tanh":
                    out_vol = torch.tanh(out_vol)
                out_vol = out_vol.to(self.save_dtype).clamp(self.clamp_min, self.clamp_max).cpu()

                if isinstance(out_vol, MetaTensor):
                    out_meta_vol = out_vol
                else:
                    # TODO: pass the input meta information through to the output.
                    out_meta_vol = MetaTensor(out_vol, meta={})

                out_name = self.output_names[j] if j < len(self.output_names) else f"out{j}"
                saver = SaveImage(
                    output_dir=str(epoch_dir),
                    output_postfix=f"{self.tag}_{mod_label}_output_{out_name}",
                    output_ext=".nii.gz",
                    separate_folder=self.separate_folder,
                )
                saver(out_meta_vol)

                del saver, out_meta_vol
                gc.collect()

            if mod_i is None:
                saved_without_modality += 1
                if saved_without_modality >= len(want - got):
                    got = set(want)
            else:
                got.add(mod_i)
            _LOG.info(
                f"SaveMRIImages Callback - Saved {mod_label} (input & {saved_names}) → {epoch_dir}"
            )

            del infer_kwargs, outputs
            gc.collect()

    missing = sorted(want - got)
    if missing:
        _LOG.warning(f"SaveMRIImages Callback - Missing modalities this epoch: {missing}")
    else:
        _LOG.info(f"SaveMRIImages Callback - Saved all requested modalities: {sorted(got)}")

    # After saving, if we used a custom dataset, tear it down and free memory
    if using_custom_dm:
        del vloader, dm
        gc.collect()
        torch.cuda.empty_cache()
|
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence
|
|
5
|
+
|
|
6
|
+
import pytorch_lightning as pl
|
|
7
|
+
|
|
8
|
+
from brainmint.utils.schedules import PiecewiseSchedule
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _normalize_probs(probs: Mapping[str, float]) -> Dict[str, float]:
|
|
14
|
+
out = {str(k): float(v) for k, v in dict(probs).items()}
|
|
15
|
+
s = float(sum(max(0.0, v) for v in out.values()))
|
|
16
|
+
if s <= 0.0:
|
|
17
|
+
raise ValueError(f"Cannot normalize probabilities with non-positive sum: {out}")
|
|
18
|
+
return {k: max(0.0, v) / s for k, v in out.items()}
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _iter_transforms(root: Any) -> Iterable[Any]:
|
|
22
|
+
"""Depth-first walk over transform graphs (matches tests/data/test_translation_configs.py)."""
|
|
23
|
+
if root is None:
|
|
24
|
+
return
|
|
25
|
+
stack = [root]
|
|
26
|
+
seen = set()
|
|
27
|
+
while stack:
|
|
28
|
+
current = stack.pop()
|
|
29
|
+
if id(current) in seen:
|
|
30
|
+
continue
|
|
31
|
+
seen.add(id(current))
|
|
32
|
+
yield current
|
|
33
|
+
|
|
34
|
+
children: List[Any] = []
|
|
35
|
+
nested = getattr(current, "_transform", None)
|
|
36
|
+
if nested is not None and nested is not current:
|
|
37
|
+
children.append(nested)
|
|
38
|
+
|
|
39
|
+
transforms = getattr(current, "transforms", None)
|
|
40
|
+
if transforms:
|
|
41
|
+
children.extend(list(transforms))
|
|
42
|
+
|
|
43
|
+
extra_start = getattr(current, "extra_xforms_start", None)
|
|
44
|
+
if extra_start:
|
|
45
|
+
children.extend(list(extra_start))
|
|
46
|
+
|
|
47
|
+
extra_end = getattr(current, "extra_xforms_end", None)
|
|
48
|
+
if extra_end:
|
|
49
|
+
children.extend(list(extra_end))
|
|
50
|
+
|
|
51
|
+
stack.extend(reversed(children))
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _find_choice_states(dm: Any) -> List[Any]:
    """Collect choice-state objects from the datamodule's training transforms.

    The walk starts at ``dm.train_tf.transform`` when that is truthy,
    otherwise at ``dm.train_tf``. A transform contributes its ``state``
    attribute when that object has both ``get_choices`` and ``set_choices``.
    """
    train_tf = getattr(dm, "train_tf", None)
    root = getattr(train_tf, "transform", None) or train_tf

    candidates = (getattr(node, "state", None) for node in _iter_transforms(root))
    return [
        cand
        for cand in candidates
        if cand is not None and hasattr(cand, "get_choices") and hasattr(cand, "set_choices")
    ]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _find_sampling_states(dm: Any) -> List[Any]:
    """Collect sampling-state objects from the datamodule's training transforms.

    The walk starts at ``dm.train_tf.transform`` when that is truthy,
    otherwise at ``dm.train_tf``. A transform contributes its
    ``sampling_state`` attribute when that object has both ``get_config``
    and ``set_config``.
    """
    train_tf = getattr(dm, "train_tf", None)
    root = getattr(train_tf, "transform", None) or train_tf

    candidates = (getattr(node, "sampling_state", None) for node in _iter_transforms(root))
    return [
        cand
        for cand in candidates
        if cand is not None and hasattr(cand, "get_config") and hasattr(cand, "set_config")
    ]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class ModalityCompletionScheduleCallback(pl.Callback):
|
|
77
|
+
"""
|
|
78
|
+
Generic, config-driven epoch-boundary scheduler for modality completion experiments.
|
|
79
|
+
|
|
80
|
+
What it can update:
|
|
81
|
+
1) Bucket sampler probabilities via DataModule.set_bucket_probs(...)
|
|
82
|
+
2) Stream selection probabilities via SharedChoiceState.set_choices(...)
|
|
83
|
+
3) Partial sampling config via SharedSamplingState.set_config(...)
|
|
84
|
+
|
|
85
|
+
All behavior is driven by YAML config (step/linear schedules). No stage hardcoding.
|
|
86
|
+
"""
|
|
87
|
+
|
|
88
|
+
def __init__(
    self,
    *,
    bucket_schedule: Optional[Sequence[Mapping[str, Any]]] = None,
    choice_schedule: Optional[Sequence[Mapping[str, Any]]] = None,
    sampling_schedule: Optional[Sequence[Mapping[str, Any]]] = None,
    update_on: str = "epoch_start",
    strict: bool = True,
    normalize_bucket_probs: bool = True,
    normalize_choice_probs: bool = True,
    auto_complement_two_way: bool = True,
    log_updates: bool = True,
) -> None:
    """Validate the options and build one ``PiecewiseSchedule`` per schedule.

    Args:
        bucket_schedule: Schedule entries for bucket sampler probabilities.
        choice_schedule: Schedule entries for stream selection probabilities.
        sampling_schedule: Schedule entries for the partial sampling config.
        update_on: ``"epoch_start"`` or ``"epoch_end"`` (case and
            surrounding whitespace are ignored).
        strict: Stored as a bool; when set, missing targets raise instead
            of being skipped.
        normalize_bucket_probs: Stored as a bool; normalize bucket
            probabilities before applying them.
        normalize_choice_probs: Stored as a bool.
        auto_complement_two_way: Stored as a bool.
        log_updates: Stored as a bool; log each applied update.

    Raises:
        ValueError: If ``update_on`` is not one of the two supported values.
    """
    super().__init__()

    mode = str(update_on).lower().strip()
    self.update_on = mode
    if mode not in ("epoch_start", "epoch_end"):
        raise ValueError("update_on must be one of: 'epoch_start', 'epoch_end'")

    # Boolean switches, coerced so truthy config values behave as flags.
    self.strict = bool(strict)
    self.normalize_bucket_probs = bool(normalize_bucket_probs)
    self.normalize_choice_probs = bool(normalize_choice_probs)
    self.auto_complement_two_way = bool(auto_complement_two_way)
    self.log_updates = bool(log_updates)

    # One schedule object per updatable target.
    self.bucket_sched = PiecewiseSchedule(bucket_schedule, name="bucket_schedule")
    self.choice_sched = PiecewiseSchedule(choice_schedule, name="choice_schedule")
    self.sampling_sched = PiecewiseSchedule(sampling_schedule, name="sampling_schedule")

    # Filled in by setup() once the trainer's datamodule is available.
    self._choice_states: List[Any] = []
    self._sampling_states: List[Any] = []
|
|
118
|
+
|
|
119
|
+
def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Optional[str] = None) -> None:
|
|
120
|
+
if stage not in (None, "fit"):
|
|
121
|
+
return
|
|
122
|
+
dm = trainer.datamodule
|
|
123
|
+
self._choice_states = _find_choice_states(dm)
|
|
124
|
+
self._sampling_states = _find_sampling_states(dm)
|
|
125
|
+
|
|
126
|
+
if self.strict:
|
|
127
|
+
if (self.choice_sched.steps or self.choice_sched.lines) and not self._choice_states:
|
|
128
|
+
raise RuntimeError("choice_schedule provided but no SharedChoiceState found in train transforms.")
|
|
129
|
+
if (self.sampling_sched.steps or self.sampling_sched.lines) and not self._sampling_states:
|
|
130
|
+
raise RuntimeError("sampling_schedule provided but no SharedSamplingState found in train transforms.")
|
|
131
|
+
|
|
132
|
+
def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
|
|
133
|
+
if self.update_on == "epoch_start":
|
|
134
|
+
self._apply(trainer)
|
|
135
|
+
|
|
136
|
+
def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
|
|
137
|
+
if self.update_on == "epoch_end":
|
|
138
|
+
self._apply(trainer)
|
|
139
|
+
|
|
140
|
+
def _apply(self, trainer: pl.Trainer) -> None:
|
|
141
|
+
epoch = int(trainer.current_epoch)
|
|
142
|
+
dm = trainer.datamodule
|
|
143
|
+
|
|
144
|
+
bucket_probs = self.bucket_sched.value_at(epoch)
|
|
145
|
+
choice_updates = self.choice_sched.value_at(epoch)
|
|
146
|
+
sampling_updates = self.sampling_sched.value_at(epoch)
|
|
147
|
+
|
|
148
|
+
if bucket_probs is not None:
|
|
149
|
+
self._apply_bucket_probs(dm, bucket_probs, epoch)
|
|
150
|
+
|
|
151
|
+
if choice_updates is not None:
|
|
152
|
+
self._apply_choice_probs(choice_updates, epoch)
|
|
153
|
+
|
|
154
|
+
if sampling_updates is not None:
|
|
155
|
+
self._apply_sampling_config(sampling_updates)
|
|
156
|
+
|
|
157
|
+
# ---------------------------
|
|
158
|
+
# Bucket probs
|
|
159
|
+
# ---------------------------
|
|
160
|
+
def _apply_bucket_probs(self, dm: Any, value: Any, epoch: int) -> None:
    """Push a scheduled ``bucket -> probability`` mapping to the datamodule.

    Keys are converted to ``str`` and values to ``float`` (optionally
    normalized via ``_normalize_probs``) before being handed to
    ``dm.set_bucket_probs``.

    Raises:
        TypeError: ``value`` is not a mapping.
        RuntimeError: In strict mode, when ``dm`` has no callable
            ``set_bucket_probs``; otherwise a warning is logged instead.
    """
    if not isinstance(value, Mapping):
        raise TypeError(f"bucket_schedule value must be a mapping bucket->prob, got {type(value)}")

    bucket_weights = {}
    for name, weight in dict(value).items():
        key = str(name)
        bucket_weights[key] = float(weight)
    if self.normalize_bucket_probs:
        bucket_weights = _normalize_probs(bucket_weights)

    apply_fn = getattr(dm, "set_bucket_probs", None)
    if callable(apply_fn):
        apply_fn(bucket_weights)
        if self.log_updates:
            logger.info("[Schedule] bucket_probs(epoch=%d): %s", int(epoch), bucket_weights)
        return

    # The datamodule cannot accept bucket probabilities: fail or warn.
    problem = "DataModule has no set_bucket_probs(); cannot apply bucket_schedule."
    if self.strict:
        raise RuntimeError(problem)
    logger.warning(problem)
|
|
180
|
+
|
|
181
|
+
# ---------------------------
|
|
182
|
+
# Choice probs (SharedChoiceState)
|
|
183
|
+
# ---------------------------
|
|
184
|
+
def _apply_choice_probs(self, value: Any, epoch: int) -> None:
    """Merge scheduled alias probabilities into every tracked choice state.

    Args:
        value: Nested mapping ``bucket -> modality -> alias -> probability``.
            Modality names are lower-cased before lookup; a modality without
            its own entry falls back to the bucket's ``"*"`` entry.
        epoch: Epoch forwarded to ``state.set_epoch`` and used in log output.

    Raises:
        TypeError: ``value`` is not a mapping, or (strict mode only) a nested
            level of the schedule or of the current choices is not a mapping.
        KeyError: (strict mode only) unknown bucket, no config for a modality
            and no ``"*"`` fallback, empty ``streams``, or an alias that is
            not listed in ``streams``.

    In non-strict mode the offending entry is skipped instead of raising.
    """
    if not isinstance(value, Mapping):
        raise TypeError(f"choice_schedule value must be a nested mapping, got {type(value)}")

    updates: Dict[str, Any] = dict(value)

    for state in self._choice_states:
        # Shallow copy: touched buckets are rebuilt below, never edited in place.
        cur = dict(state.get_choices())
        changed = False

        for bucket, mods in updates.items():
            bucket = str(bucket)
            if bucket not in cur:
                if self.strict:
                    raise KeyError(f"choice_schedule refers to unknown bucket '{bucket}'")
                logger.warning("choice_schedule: unknown bucket '%s' (ignored)", bucket)
                continue
            if not isinstance(mods, Mapping):
                if self.strict:
                    raise TypeError(f"choice_schedule[{bucket}] must be a mapping modality->alias_probs")
                continue

            if not isinstance(cur[bucket], Mapping):
                if self.strict:
                    raise TypeError(f"choices[{bucket}] must be mapping")
                continue

            for mod, alias_probs in mods.items():
                mod_key = str(mod).lower()
                if not isinstance(alias_probs, Mapping):
                    if self.strict:
                        raise TypeError(f"choice_schedule[{bucket}][{mod_key}] must be mapping alias->prob")
                    continue

                # Re-read the bucket config for every modality. Several
                # modalities can resolve to the same entry (e.g. all falling
                # back to "*"); reading a snapshot taken before the loop would
                # silently discard the probabilities set by earlier ones.
                bcfg = cur[bucket]

                # Find modality config; fall back to "*"
                mcfg = bcfg.get(mod_key, None)
                wildcard = False
                if mcfg is None:
                    mcfg = bcfg.get("*", None)
                    wildcard = True
                if mcfg is None:
                    if self.strict:
                        raise KeyError(f"choices missing config for bucket='{bucket}' modality='{mod_key}' (and no '*')")
                    logger.warning("choices missing config for bucket='%s' modality='%s' (ignored)", bucket, mod_key)
                    continue

                # Key of the entry actually being read and rewritten.
                cfg_key = "*" if wildcard else mod_key
                if not isinstance(mcfg, Mapping):
                    if self.strict:
                        raise TypeError(f"choices[{bucket}][{cfg_key}] must be mapping")
                    continue

                streams = dict(mcfg.get("streams") or {})
                if not streams:
                    if self.strict:
                        raise KeyError(f"choices[{bucket}][{cfg_key}].streams is empty")
                    continue

                probs = dict(mcfg.get("probs") or {})

                # apply updates (set)
                for alias, p in alias_probs.items():
                    probs[str(alias)] = float(p)

                unknown_aliases = {str(a) for a in alias_probs} - set(streams)
                if unknown_aliases:
                    if self.strict:
                        raise KeyError(
                            f"choice_schedule[{bucket}][{mod_key}] refers to unknown stream aliases: "
                            f"{sorted(unknown_aliases)}"
                        )
                    # These aliases are dropped by the filter below; say so.
                    logger.warning(
                        "choice_schedule[%s][%s]: unknown stream aliases %s (ignored)",
                        bucket,
                        mod_key,
                        sorted(unknown_aliases),
                    )

                # Optional complement for 2-way choices when only one alias was specified
                if self.auto_complement_two_way:
                    defined_aliases = list(streams)
                    specified = [str(a) for a in alias_probs if str(a) in streams]
                    if len(defined_aliases) == 2 and len(specified) == 1:
                        a0 = specified[0]
                        other = defined_aliases[0] if defined_aliases[1] == a0 else defined_aliases[1]
                        p0 = float(probs.get(a0, 0.0))
                        probs[other] = max(0.0, 1.0 - p0)

                # Keep only aliases defined in streams
                probs = {a: float(probs.get(a, 0.0)) for a in streams}

                if self.normalize_choice_probs:
                    probs = _normalize_probs(probs)

                new_mcfg = dict(mcfg)
                new_mcfg["probs"] = probs

                new_bucket_cfg = dict(bcfg)
                new_bucket_cfg[cfg_key] = new_mcfg
                cur[bucket] = new_bucket_cfg

                changed = True

        if changed:
            try:
                state.set_epoch(int(epoch))
            except Exception:
                # Best effort: recording the epoch must not block the update,
                # but the failure should still be discoverable.
                logger.debug("choice state set_epoch(%s) failed; continuing", epoch, exc_info=True)
            state.set_choices(cur)
            if self.log_updates:
                logger.info("[Schedule] choice_probs(epoch=%d): updated SharedChoiceState", epoch)
|
|
291
|
+
|
|
292
|
+
# ---------------------------
|
|
293
|
+
# Sampling config (SharedSamplingState)
|
|
294
|
+
# ---------------------------
|
|
295
|
+
def _apply_sampling_config(self, value: Any) -> None:
    """Merge a scheduled sampling update into every tracked sampling state.

    Keys of ``value`` are converted to ``str``. Each state's current config
    (``get_config``) is copied, updated with the new entries and written back
    with ``set_config``.

    Raises:
        TypeError: ``value`` is not a mapping.
        ValueError: In strict mode, when ``sigma_prob`` is outside ``[0, 1]``
            or ``sigma_alpha`` is negative. NaN is rejected for both.
    """
    if not isinstance(value, Mapping):
        raise TypeError(f"sampling_schedule value must be a mapping, got {type(value)}")

    upd = {str(k): v for k, v in dict(value).items()}

    # Coercion to float and range checks only happen in strict mode.
    if self.strict:
        if "sigma_prob" in upd:
            upd["sigma_prob"] = float(upd["sigma_prob"])
            if not (0.0 <= upd["sigma_prob"] <= 1.0):
                raise ValueError("sampling_schedule sigma_prob must be between 0 and 1.")
        if "sigma_alpha" in upd:
            upd["sigma_alpha"] = float(upd["sigma_alpha"])
            # Written as "not >=" so NaN fails too; "nan < 0.0" is False and
            # would let it through, unlike the sigma_prob check above.
            if not (upd["sigma_alpha"] >= 0.0):
                raise ValueError("sampling_schedule sigma_alpha must be >= 0.")

    for ss in self._sampling_states:
        cur = dict(ss.get_config())
        cur.update(upd)
        ss.set_config(cur)

    if self.log_updates:
        logger.info("[Schedule] sampling_config: %s", upd)
|