nshtrainer-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +64 -0
- nshtrainer/_experimental/__init__.py +2 -0
- nshtrainer/_experimental/flops/__init__.py +48 -0
- nshtrainer/_experimental/flops/flop_counter.py +787 -0
- nshtrainer/_experimental/flops/module_tracker.py +140 -0
- nshtrainer/_snoop.py +216 -0
- nshtrainer/_submit/print_environment_info.py +31 -0
- nshtrainer/_submit/session/_output.py +12 -0
- nshtrainer/_submit/session/_script.py +109 -0
- nshtrainer/_submit/session/lsf.py +467 -0
- nshtrainer/_submit/session/slurm.py +573 -0
- nshtrainer/_submit/session/unified.py +350 -0
- nshtrainer/actsave/__init__.py +7 -0
- nshtrainer/actsave/_callback.py +75 -0
- nshtrainer/actsave/_loader.py +144 -0
- nshtrainer/actsave/_saver.py +337 -0
- nshtrainer/callbacks/__init__.py +35 -0
- nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
- nshtrainer/callbacks/base.py +113 -0
- nshtrainer/callbacks/early_stopping.py +112 -0
- nshtrainer/callbacks/ema.py +383 -0
- nshtrainer/callbacks/finite_checks.py +75 -0
- nshtrainer/callbacks/gradient_skipping.py +103 -0
- nshtrainer/callbacks/interval.py +322 -0
- nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
- nshtrainer/callbacks/log_epoch.py +35 -0
- nshtrainer/callbacks/norm_logging.py +187 -0
- nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
- nshtrainer/callbacks/print_table.py +90 -0
- nshtrainer/callbacks/throughput_monitor.py +56 -0
- nshtrainer/callbacks/timer.py +157 -0
- nshtrainer/callbacks/wandb_watch.py +103 -0
- nshtrainer/config.py +289 -0
- nshtrainer/data/__init__.py +4 -0
- nshtrainer/data/balanced_batch_sampler.py +132 -0
- nshtrainer/data/transform.py +67 -0
- nshtrainer/lr_scheduler/__init__.py +18 -0
- nshtrainer/lr_scheduler/_base.py +101 -0
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
- nshtrainer/model/__init__.py +44 -0
- nshtrainer/model/base.py +641 -0
- nshtrainer/model/config.py +2064 -0
- nshtrainer/model/modules/callback.py +157 -0
- nshtrainer/model/modules/debug.py +42 -0
- nshtrainer/model/modules/distributed.py +70 -0
- nshtrainer/model/modules/logger.py +170 -0
- nshtrainer/model/modules/profiler.py +24 -0
- nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
- nshtrainer/model/modules/shared_parameters.py +72 -0
- nshtrainer/nn/__init__.py +19 -0
- nshtrainer/nn/mlp.py +106 -0
- nshtrainer/nn/module_dict.py +66 -0
- nshtrainer/nn/module_list.py +50 -0
- nshtrainer/nn/nonlinearity.py +157 -0
- nshtrainer/optimizer.py +62 -0
- nshtrainer/runner.py +21 -0
- nshtrainer/scripts/check_env.py +41 -0
- nshtrainer/scripts/find_packages.py +51 -0
- nshtrainer/trainer/__init__.py +1 -0
- nshtrainer/trainer/signal_connector.py +208 -0
- nshtrainer/trainer/trainer.py +340 -0
- nshtrainer/typecheck.py +144 -0
- nshtrainer/util/environment.py +119 -0
- nshtrainer/util/seed.py +11 -0
- nshtrainer/util/singleton.py +89 -0
- nshtrainer/util/slurm.py +49 -0
- nshtrainer/util/typed.py +2 -0
- nshtrainer/util/typing_utils.py +19 -0
- nshtrainer-0.1.0.dist-info/METADATA +18 -0
- nshtrainer-0.1.0.dist-info/RECORD +72 -0
- nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/callbacks/print_table.py
ADDED

@@ -0,0 +1,90 @@
import copy
import fnmatch
import importlib.util
import logging
from typing import Literal

import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from typing_extensions import override

from .base import CallbackConfigBase

log = logging.getLogger(__name__)


class PrintTableMetricsCallback(Callback):
    """Prints a table with the metrics in columns on every epoch end."""

    def __init__(
        self,
        metric_patterns: list[str] | None = None,
    ) -> None:
        self.metrics: list = []
        self.rich_available = importlib.util.find_spec("rich") is not None
        self.metric_patterns = metric_patterns

        if not self.rich_available:
            log.warning(
                "rich is not installed. Please install it to use PrintTableMetricsCallback."
            )

    @override
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        if not self.rich_available:
            return

        metrics_dict = copy.copy(trainer.callback_metrics)
        # Filter metrics based on the patterns
        if self.metric_patterns is not None:
            metrics_dict = {
                key: value
                for key, value in metrics_dict.items()
                if any(
                    fnmatch.fnmatch(key, pattern) for pattern in self.metric_patterns
                )
            }
        self.metrics.append(metrics_dict)

        from rich.console import Console

        console = Console()
        table = self.create_metrics_table()
        console.print(table)

    def create_metrics_table(self):
        from rich.table import Table

        table = Table(show_header=True, header_style="bold magenta")

        # Add columns to the table based on the keys in the first metrics dictionary
        for key in self.metrics[0].keys():
            table.add_column(key)

        # Add rows to the table based on the metrics dictionaries
        for metric_dict in self.metrics:
            values: list[str] = []
            for value in metric_dict.values():
                if torch.is_tensor(value):
                    value = float(value.item())
                values.append(str(value))
            table.add_row(*values)

        return table


class PrintTableMetricsConfig(CallbackConfigBase):
    """Configuration class for PrintTableMetricsCallback."""

    name: Literal["print_table_metrics"] = "print_table_metrics"

    enabled: bool = True
    """Whether to enable the callback or not."""

    metric_patterns: list[str] | None = None
    """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""

    @override
    def construct_callbacks(self, root_config):
        yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
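For orientation, a minimal usage sketch (not part of the package): the callback can be attached directly to a Lightning Trainer, or built via PrintTableMetricsConfig's construct_callbacks. The import path follows the file layout listed above; the model and datamodule are assumed to exist.

from lightning.pytorch import Trainer

from nshtrainer.callbacks.print_table import PrintTableMetricsCallback

# Only print metrics whose names match these fnmatch-style patterns.
callback = PrintTableMetricsCallback(metric_patterns=["train/*", "val/loss"])
trainer = Trainer(max_epochs=10, callbacks=[callback])
# trainer.fit(model, datamodule=datamodule)  # model/datamodule assumed to exist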
nshtrainer/callbacks/throughput_monitor.py
ADDED

@@ -0,0 +1,56 @@
from logging import getLogger
from typing import Any, Literal, Protocol, TypedDict, cast, runtime_checkable

from typing_extensions import NotRequired, override

from ._throughput_monitor_callback import ThroughputMonitor as _ThroughputMonitor
from .base import CallbackConfigBase

log = getLogger(__name__)


class ThroughputMonitorBatchStats(TypedDict):
    batch_size: int
    length: NotRequired[int | None]


@runtime_checkable
class SupportsThroughputMonitorModuleProtocol(Protocol):
    def throughput_monitor_batch_stats(
        self, batch: Any
    ) -> ThroughputMonitorBatchStats: ...


class ThroughputMonitor(_ThroughputMonitor):
    def __init__(self, window_size: int = 100) -> None:
        super().__init__(cast(Any, None), cast(Any, None), window_size=window_size)

    @override
    def setup(self, trainer, pl_module, stage):
        if not isinstance(pl_module, SupportsThroughputMonitorModuleProtocol):
            raise RuntimeError(
                "The model does not implement `throughput_monitor_batch_stats`. "
                "Please either implement this method, or do not use the `ThroughputMonitor` callback."
            )

        def batch_size_fn(batch):
            return pl_module.throughput_monitor_batch_stats(batch)["batch_size"]

        def length_fn(batch):
            return pl_module.throughput_monitor_batch_stats(batch).get("length")

        self.batch_size_fn = batch_size_fn
        self.length_fn = length_fn

        return super().setup(trainer, pl_module, stage)


class ThroughputMonitorConfig(CallbackConfigBase):
    name: Literal["throughput_monitor"] = "throughput_monitor"

    window_size: int = 100
    """Number of batches to use for a rolling average."""

    @override
    def construct_callbacks(self, root_config):
        yield ThroughputMonitor(window_size=self.window_size)
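The callback refuses to run unless the attached module satisfies SupportsThroughputMonitorModuleProtocol. A minimal sketch (not part of the package) of a conforming module; the batch layout assumed here is hypothetical.

from typing import Any

from lightning.pytorch import LightningModule

from nshtrainer.callbacks.throughput_monitor import ThroughputMonitorBatchStats


class MyModule(LightningModule):
    # Report per-batch statistics so the monitor can derive throughput;
    # "length" is optional and only needed for per-item/token rates.
    def throughput_monitor_batch_stats(self, batch: Any) -> ThroughputMonitorBatchStats:
        # Assumes `batch` is a dict holding an "input" tensor of shape (B, L).
        return {"batch_size": batch["input"].shape[0], "length": batch["input"].shape[1]}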
nshtrainer/callbacks/timer.py
ADDED

@@ -0,0 +1,157 @@
import logging
import time
from typing import Any, Literal

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT
from typing_extensions import override

from .base import CallbackConfigBase

log = logging.getLogger(__name__)


class EpochTimer(Callback):
    def __init__(self):
        super().__init__()

        self._start_time: dict[str, float] = {}
        self._elapsed_time: dict[str, float] = {}
        self._total_batches: dict[str, int] = {}

    @override
    def on_train_epoch_start(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._start_time["train"] = time.monotonic()
        self._total_batches["train"] = 0

    @override
    def on_train_batch_end(
        self,
        trainer: "Trainer",
        pl_module: "LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        self._total_batches["train"] += 1

    @override
    def on_train_epoch_end(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._elapsed_time["train"] = time.monotonic() - self._start_time["train"]
        if trainer.is_global_zero:
            self._log_epoch_info("train")

    @override
    def on_validation_epoch_start(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._start_time["val"] = time.monotonic()
        self._total_batches["val"] = 0

    @override
    def on_validation_batch_end(
        self,
        trainer: "Trainer",
        pl_module: "LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._total_batches["val"] += 1

    @override
    def on_validation_epoch_end(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._elapsed_time["val"] = time.monotonic() - self._start_time["val"]
        if trainer.is_global_zero:
            self._log_epoch_info("val")

    @override
    def on_test_epoch_start(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._start_time["test"] = time.monotonic()
        self._total_batches["test"] = 0

    @override
    def on_test_batch_end(
        self,
        trainer: "Trainer",
        pl_module: "LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._total_batches["test"] += 1

    @override
    def on_test_epoch_end(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._elapsed_time["test"] = time.monotonic() - self._start_time["test"]
        if trainer.is_global_zero:
            self._log_epoch_info("test")

    @override
    def on_predict_epoch_start(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._start_time["predict"] = time.monotonic()
        self._total_batches["predict"] = 0

    @override
    def on_predict_batch_end(
        self,
        trainer: "Trainer",
        pl_module: "LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._total_batches["predict"] += 1

    @override
    def on_predict_epoch_end(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._elapsed_time["predict"] = time.monotonic() - self._start_time["predict"]
        if trainer.is_global_zero:
            self._log_epoch_info("predict")

    def _log_epoch_info(self, stage: str) -> None:
        if (elapsed_time := self._elapsed_time.get(stage)) is None:
            return
        total_batches = self._total_batches[stage]
        log.critical(
            f"Epoch {stage.capitalize()} Summary: Elapsed Time: {elapsed_time:.2f} seconds | "
            f"Total Batches: {total_batches}"
        )

    @override
    def state_dict(self) -> dict[str, Any]:
        return {
            "elapsed_time": self._elapsed_time,
            "total_batches": self._total_batches,
        }

    @override
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self._elapsed_time = state_dict["elapsed_time"]
        self._total_batches = state_dict["total_batches"]


class EpochTimerConfig(CallbackConfigBase):
    name: Literal["epoch_timer"] = "epoch_timer"

    @override
    def construct_callbacks(self, root_config):
        yield EpochTimer()
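A minimal usage sketch (not part of the package): the timer is attached like any other Lightning callback and logs per-stage wall-clock time and batch counts on the rank-zero process at the end of each epoch.

from lightning.pytorch import Trainer

from nshtrainer.callbacks.timer import EpochTimer

trainer = Trainer(max_epochs=3, callbacks=[EpochTimer()])
# trainer.fit(model)  # `model` assumed to exist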
nshtrainer/callbacks/wandb_watch.py
ADDED

@@ -0,0 +1,103 @@
from logging import getLogger
from typing import Literal, Protocol, cast, runtime_checkable

import torch.nn as nn
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.loggers import WandbLogger
from typing_extensions import override

from .base import CallbackConfigBase

log = getLogger(__name__)


@runtime_checkable
class _HasWandbLogModuleProtocol(Protocol):
    def wandb_log_module(self) -> nn.Module | None: ...


class WandbWatchCallback(Callback):
    def __init__(self, config: "WandbWatchConfig"):
        super().__init__()

        self.config = config

    @override
    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._on_start(trainer, pl_module)

    @override
    def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._on_start(trainer, pl_module)

    @override
    def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._on_start(trainer, pl_module)

    @override
    def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._on_start(trainer, pl_module)

    def _on_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # If not enabled, return
        if not self.config:
            return

        # If we're in fast_dev_run, don't watch the model
        if getattr(trainer, "fast_dev_run", False):
            return

        if (
            logger := next(
                (
                    logger
                    for logger in trainer.loggers
                    if isinstance(logger, WandbLogger)
                ),
                None,
            )
        ) is None:
            log.warning("Could not find wandb logger or module to log")
            return

        if getattr(pl_module, "_model_watched", False):
            return

        # Get which module to log
        if (
            not isinstance(pl_module, _HasWandbLogModuleProtocol)
            or (module := pl_module.wandb_log_module()) is None
        ):
            module = cast(nn.Module, pl_module)

        logger.watch(
            module,
            log=cast(str, self.config.log),
            log_freq=self.config.log_freq,
            log_graph=self.config.log_graph,
        )
        setattr(pl_module, "_model_watched", True)


class WandbWatchConfig(CallbackConfigBase):
    name: Literal["finite_checks"] = "finite_checks"

    enabled: bool = True
    """Enable watching the model for wandb."""

    log: str | None = None
    """Log type for wandb."""

    log_graph: bool = True
    """Whether to log the graph for wandb."""

    log_freq: int = 100
    """Log frequency for wandb."""

    def __bool__(self):
        return self.enabled

    @override
    def construct_callbacks(self, root_config):
        yield WandbWatchCallback(self)
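A minimal sketch (not part of the package) of wiring this up; it assumes a WandbLogger is attached to the trainer, since the callback only logs a warning and skips watching otherwise. A model may also expose wandb_log_module() to watch a specific submodule instead of the whole LightningModule.

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger

from nshtrainer.callbacks.wandb_watch import WandbWatchCallback, WandbWatchConfig

# `log` is forwarded to wandb's watch(); "gradients" and "all" are common values.
config = WandbWatchConfig(log="gradients", log_freq=200)
trainer = Trainer(
    logger=WandbLogger(project="my-project"),  # project name is hypothetical
    callbacks=[WandbWatchCallback(config)],
)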
nshtrainer/config.py
ADDED
@@ -0,0 +1,289 @@
from collections.abc import Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, ClassVar

from pydantic import BaseModel, ConfigDict
from pydantic import Field as Field
from pydantic import PrivateAttr as PrivateAttr
from typing_extensions import deprecated, override

from ._config.missing import MISSING, validate_no_missing_values
from ._config.missing import AllowMissing as AllowMissing
from ._config.missing import MissingField as MissingField

_MutableMappingBase = MutableMapping[str, Any]
if TYPE_CHECKING:
    _MutableMappingBase = object


_DraftConfigContextSentinel = object()


class TypedConfig(BaseModel, _MutableMappingBase):
    _is_draft_config: bool = PrivateAttr(default=False)
    """
    Whether this config is a draft config or not.

    Draft configs are configs that are not yet fully validated.
    They allow for a nicer API when creating configs, e.g.:

    ```python
    config = MyConfig.draft()

    # Set some values
    config.a = 10
    config.b = "hello"

    # Finalize the config
    config = config.finalize()
    ```
    """

    repr_diff_only: ClassVar[bool] = True
    """
    If `True`, the repr methods will only show values for fields that are different from the default.
    """

    MISSING: ClassVar[Any] = MISSING
    """
    Alias for the `MISSING` constant.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(
        # By default, Pydantic will throw a warning if a field starts with "model_",
        # so we need to disable that warning (beacuse "model_" is a popular prefix for ML).
        protected_namespaces=(),
        validate_assignment=True,
        validate_return=True,
        validate_default=True,
        strict=True,
        revalidate_instances="always",
        arbitrary_types_allowed=True,
        extra="ignore",
        validation_error_cause=True,
        use_attribute_docstrings=True,
    )

    def __draft_pre_init__(self):
        """Called right before a draft config is finalized."""
        pass

    def __post_init__(self):
        """Called after the final config is validated."""
        pass

    @classmethod
    @deprecated("Use `model_validate` instead.")
    def from_dict(cls, model_dict: Mapping[str, Any]):
        return cls.model_validate(model_dict)

    def model_deep_validate(self, strict: bool = True):
        """
        Validate the config and all of its sub-configs.

        Args:
            config: The config to validate.
            strict: Whether to validate the config strictly.
        """
        config_dict = self.model_dump(round_trip=True)
        config = self.model_validate(config_dict, strict=strict)

        # Make sure that this is not a draft config
        if config._is_draft_config:
            raise ValueError("Draft configs are not valid. Call `finalize` first.")

        return config

    @classmethod
    def draft(cls, **kwargs):
        config = cls.model_construct_draft(**kwargs)
        return config

    def finalize(self, strict: bool = True):
        # This must be a draft config, otherwise we raise an error
        if not self._is_draft_config:
            raise ValueError("Finalize can only be called on drafts.")

        # First, we call `__draft_pre_init__` to allow the config to modify itself a final time
        self.__draft_pre_init__()

        # Then, we dump the config to a dict and then re-validate it
        return self.model_deep_validate(strict=strict)

    @override
    def model_post_init(self, __context: Any) -> None:
        super().model_post_init(__context)

        # Call the `__post_init__` method if this is not a draft config
        if __context is _DraftConfigContextSentinel:
            return

        self.__post_init__()

        # After `_post_init__` is called, we perform the final round of validation
        self.model_post_init_validate()

    def model_post_init_validate(self):
        validate_no_missing_values(self)

    @classmethod
    def model_construct_draft(cls, _fields_set: set[str] | None = None, **values: Any):
        """
        NOTE: This is a copy of the `model_construct` method from Pydantic's `Model` class,
        with the following changes:
        - The `model_post_init` method is called with the `_DraftConfigContext` context.
        - The `_is_draft_config` attribute is set to `True` in the `values` dict.

        Creates a new instance of the `Model` class with validated data.

        Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
        Default values are respected, but no other validation is performed.

        !!! note
            `model_construct()` generally respects the `model_config.extra` setting on the provided model.
            That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
            and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
            Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
            an error if extra values are passed, but they will be ignored.

        Args:
            _fields_set: The set of field names accepted for the Model instance.
            values: Trusted or pre-validated data dictionary.

        Returns:
            A new instance of the `Model` class with validated data.
        """

        values["_is_draft_config"] = True

        m = cls.__new__(cls)
        fields_values: dict[str, Any] = {}
        fields_set = set()

        for name, field in cls.model_fields.items():
            if field.alias and field.alias in values:
                fields_values[name] = values.pop(field.alias)
                fields_set.add(name)
            elif name in values:
                fields_values[name] = values.pop(name)
                fields_set.add(name)
            elif not field.is_required():
                fields_values[name] = field.get_default(call_default_factory=True)
        if _fields_set is None:
            _fields_set = fields_set

        _extra: dict[str, Any] | None = None
        if cls.model_config.get("extra") == "allow":
            _extra = {}
            for k, v in values.items():
                _extra[k] = v
        object.__setattr__(m, "__dict__", fields_values)
        object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
        if not cls.__pydantic_root_model__:
            object.__setattr__(m, "__pydantic_extra__", _extra)

        if cls.__pydantic_post_init__:
            m.model_post_init(_DraftConfigContextSentinel)
            # update private attributes with values set
            if (
                hasattr(m, "__pydantic_private__")
                and m.__pydantic_private__ is not None
            ):
                for k, v in values.items():
                    if k in m.__private_attributes__:
                        m.__pydantic_private__[k] = v

        elif not cls.__pydantic_root_model__:
            # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist
            # Since it doesn't, that means that `__pydantic_private__` should be set to None
            object.__setattr__(m, "__pydantic_private__", None)

        return m

    @override
    def __repr_args__(self):
        # If `repr_diff_only` is `True`, we only show the fields that are different from the default.
        if not self.repr_diff_only:
            yield from super().__repr_args__()
            return

        # First, we get the default values for all fields.
        default_values = self.model_construct_draft()

        # Then, we compare the default values with the current values.
        for k, v in super().__repr_args__():
            if k is None:
                yield k, v
                continue

            # If there is no default value or the value is different from the default, we yield it.
            if not hasattr(default_values, k) or getattr(default_values, k) != v:
                yield k, v
                continue

            # Otherwise, we can skip this field.

    # region MutableMapping implementation
    if not TYPE_CHECKING:
        # This is mainly so the config can be used with lightning's hparams
        # transparently and without any issues.

        @property
        def _ll_dict(self):
            return self.model_dump()

        # We need to make sure every config class
        # is a MutableMapping[str, Any] so that it can be used
        # with lightning's hparams.
        @override
        def __getitem__(self, key: str):
            # Key can be of the format "a.b.c"
            # so we need to split it into a list of keys.
            [first_key, *rest_keys] = key.split(".")
            value = self._ll_dict[first_key]

            for key in rest_keys:
                if isinstance(value, Mapping):
                    value = value[key]
                else:
                    value = getattr(value, key)

            return value

        @override
        def __setitem__(self, key: str, value: Any):
            # Key can be of the format "a.b.c"
            # so we need to split it into a list of keys.
            [first_key, *rest_keys] = key.split(".")
            if len(rest_keys) == 0:
                self._ll_dict[first_key] = value
                return

            # We need to traverse the keys until we reach the last key
            # and then set the value
            current_value = self._ll_dict[first_key]
            for key in rest_keys[:-1]:
                if isinstance(current_value, Mapping):
                    current_value = current_value[key]
                else:
                    current_value = getattr(current_value, key)

            # Set the value
            if isinstance(current_value, MutableMapping):
                current_value[rest_keys[-1]] = value
            else:
                setattr(current_value, rest_keys[-1], value)

        @override
        def __delitem__(self, key: str):
            # This is unsupported for this class
            raise NotImplementedError

        @override
        def __iter__(self):
            return iter(self._ll_dict)

        @override
        def __len__(self):
            return len(self._ll_dict)

    # endregion
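To make the draft workflow and mapping behaviour concrete, here is a small sketch (not part of the package); the subclass and field names are made up for illustration. Dotted keys traverse nested configs through __getitem__, which is what lets Lightning treat a TypedConfig like an hparams mapping.

from nshtrainer.config import TypedConfig


class OptimizerConfig(TypedConfig):
    lr: float = 1e-3


class RunConfig(TypedConfig):
    run_name: str = "baseline"
    optimizer: OptimizerConfig = OptimizerConfig()


config = RunConfig()
print(config["optimizer.lr"])  # 0.001 -- dotted lookup into the nested config
print(len(config))             # number of top-level fields in model_dump()

# Draft workflow from the docstring above: build up values, then validate.
draft = RunConfig.draft()
draft.run_name = "sweep-1"
final = draft.finalize()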