nshtrainer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,90 @@
1
+ import copy
2
+ import fnmatch
3
+ import importlib.util
4
+ import logging
5
+ from typing import Literal
6
+
7
+ import torch
8
+ from lightning.pytorch import LightningModule, Trainer
9
+ from lightning.pytorch.callbacks import Callback
10
+ from typing_extensions import override
11
+
12
+ from .base import CallbackConfigBase
13
+
14
+ log = logging.getLogger(__name__)
15
+
16
+
17
class PrintTableMetricsCallback(Callback):
    """Prints a table with the metrics in columns on every epoch end.

    At the end of every training epoch a snapshot of
    ``trainer.callback_metrics`` (optionally filtered by ``metric_patterns``)
    is appended to ``self.metrics`` and the full history is rendered with
    ``rich``: one row per recorded epoch, one column per metric name.

    Args:
        metric_patterns: ``fnmatch``-style patterns. When given, only metrics
            whose name matches at least one pattern are recorded. ``None``
            records every metric.
    """

    def __init__(
        self,
        metric_patterns: list[str] | None = None,
    ) -> None:
        super().__init__()

        self.metrics: list = []
        self.rich_available = importlib.util.find_spec("rich") is not None
        self.metric_patterns = metric_patterns

        if not self.rich_available:
            log.warning(
                "rich is not installed. Please install it to use PrintTableMetricsCallback."
            )

    @override
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Record this epoch's metrics and print the accumulated table."""
        # Without rich there is nothing to render, so nothing is recorded either.
        if not self.rich_available:
            return

        # Shallow copy: later changes to the trainer's dict must not alter
        # the rows that were already recorded.
        metrics_dict = copy.copy(trainer.callback_metrics)
        # Filter metrics based on the patterns
        if self.metric_patterns is not None:
            metrics_dict = {
                key: value
                for key, value in metrics_dict.items()
                if any(
                    fnmatch.fnmatch(key, pattern) for pattern in self.metric_patterns
                )
            }
        self.metrics.append(metrics_dict)

        # Imported lazily so that the module can be imported without rich.
        from rich.console import Console

        console = Console()
        table = self.create_metrics_table()
        console.print(table)

    def create_metrics_table(self):
        """Build a ``rich`` table from every metrics snapshot recorded so far.

        Columns are the union of the metric names over all recorded epochs,
        in first-seen order. Each cell is looked up by metric name, so a
        metric that is missing in some epoch yields an empty cell instead of
        shifting the remaining values under the wrong header.

        Returns:
            The populated ``rich.table.Table`` (empty when nothing has been
            recorded yet).
        """
        from rich.table import Table

        table = Table(show_header=True, header_style="bold magenta")

        # The set of logged metrics may change between epochs, so the columns
        # cannot be taken from the first snapshot alone. `dict.fromkeys`
        # de-duplicates while keeping insertion order.
        columns = list(
            dict.fromkeys(key for metric_dict in self.metrics for key in metric_dict)
        )
        for key in columns:
            table.add_column(key)

        # Add one row per recorded epoch, aligned to the columns by name.
        for metric_dict in self.metrics:
            values: list[str] = []
            for key in columns:
                if key not in metric_dict:
                    values.append("")
                    continue
                value = metric_dict[key]
                if torch.is_tensor(value):
                    value = float(value.item())
                values.append(str(value))
            table.add_row(*values)

        return table
75
+
76
+
77
class PrintTableMetricsConfig(CallbackConfigBase):
    """Configuration class for PrintTableMetricsCallback."""

    name: Literal["print_table_metrics"] = "print_table_metrics"

    enabled: bool = True
    """Whether to enable the callback or not."""

    metric_patterns: list[str] | None = None
    """List of patterns to filter the metrics to be displayed. If None, all metrics are displayed."""

    @override
    def construct_callbacks(self, root_config):
        """Yield the configured callback, or nothing when disabled.

        Args:
            root_config: The root config this callback config belongs to
                (not used here).

        Yields:
            A single `PrintTableMetricsCallback` when `enabled` is true.
        """
        # Honour the documented `enabled` switch; previously it had no effect
        # in this method and the callback was always constructed.
        if not self.enabled:
            return

        yield PrintTableMetricsCallback(metric_patterns=self.metric_patterns)
@@ -0,0 +1,56 @@
1
+ from logging import getLogger
2
+ from typing import Any, Literal, Protocol, TypedDict, cast, runtime_checkable
3
+
4
+ from typing_extensions import NotRequired, override
5
+
6
+ from ._throughput_monitor_callback import ThroughputMonitor as _ThroughputMonitor
7
+ from .base import CallbackConfigBase
8
+
9
+ log = getLogger(__name__)
10
+
11
+
12
+ class ThroughputMonitorBatchStats(TypedDict):
13
+ batch_size: int
14
+ length: NotRequired[int | None]
15
+
16
+
17
@runtime_checkable
class SupportsThroughputMonitorModuleProtocol(Protocol):
    """Structural type for modules that can report per-batch statistics.

    Because the protocol is runtime-checkable, `isinstance` only verifies
    that a `throughput_monitor_batch_stats` attribute exists; the signature
    is not checked at runtime.
    """

    def throughput_monitor_batch_stats(
        self,
        batch: Any,
    ) -> ThroughputMonitorBatchStats:
        """Return the statistics dictionary for ``batch``."""
        ...
22
+
23
+
24
class ThroughputMonitor(_ThroughputMonitor):
    """Throughput monitor whose batch statistics come from the module itself.

    The base class is constructed with placeholder (``None``) batch-size and
    length functions; the real ones are installed in `setup`, once the
    module is known, and delegate to
    ``pl_module.throughput_monitor_batch_stats``.
    """

    def __init__(self, window_size: int = 100) -> None:
        # The two positional callbacks are filled in later by `setup`.
        placeholder = cast(Any, None)
        super().__init__(placeholder, placeholder, window_size=window_size)

    @override
    def setup(self, trainer, pl_module, stage):
        """Bind the batch-statistics functions to ``pl_module``.

        Raises:
            RuntimeError: If ``pl_module`` has no
                `throughput_monitor_batch_stats` method.
        """
        supported = isinstance(pl_module, SupportsThroughputMonitorModuleProtocol)
        if not supported:
            raise RuntimeError(
                "The model does not implement `throughput_monitor_batch_stats`. "
                "Please either implement this method, or do not use the `ThroughputMonitor` callback."
            )

        def _batch_size(batch):
            stats = pl_module.throughput_monitor_batch_stats(batch)
            return stats["batch_size"]

        def _length(batch):
            stats = pl_module.throughput_monitor_batch_stats(batch)
            # "length" is optional, so a missing key yields None.
            return stats.get("length")

        self.batch_size_fn = _batch_size
        self.length_fn = _length

        return super().setup(trainer, pl_module, stage)
46
+
47
+
48
class ThroughputMonitorConfig(CallbackConfigBase):
    name: Literal["throughput_monitor"] = "throughput_monitor"

    window_size: int = 100
    """Number of batches to use for a rolling average."""

    @override
    def construct_callbacks(self, root_config):
        # Exactly one callback is produced; `root_config` is not consulted.
        monitor = ThroughputMonitor(window_size=self.window_size)
        yield monitor
@@ -0,0 +1,157 @@
1
+ import logging
2
+ import time
3
+ from typing import Any, Literal
4
+
5
+ from lightning.pytorch import LightningModule, Trainer
6
+ from lightning.pytorch.callbacks import Callback
7
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
8
+ from typing_extensions import override
9
+
10
+ from .base import CallbackConfigBase
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+
15
class EpochTimer(Callback):
    """Times each train/val/test/predict epoch and counts its batches.

    For every stage the callback stores the monotonic start time, the elapsed
    time of the most recently finished epoch and the number of batches seen
    in it. When an epoch ends, a one-line summary is logged on the global-zero
    rank only.
    """

    def __init__(self):
        super().__init__()

        # All three dictionaries are keyed by stage name:
        # "train", "val", "test" or "predict".
        self._start_time: dict[str, float] = {}
        self._elapsed_time: dict[str, float] = {}
        self._total_batches: dict[str, int] = {}

    # --- shared bookkeeping used by the per-stage hooks below -------------

    def _begin_stage(self, stage: str) -> None:
        """Stamp the start time of ``stage`` and reset its batch counter."""
        self._start_time[stage] = time.monotonic()
        self._total_batches[stage] = 0

    def _count_batch(self, stage: str) -> None:
        """Register one finished batch for ``stage``."""
        self._total_batches[stage] += 1

    def _finish_stage(self, stage: str, trainer: "Trainer") -> None:
        """Store the elapsed time of ``stage`` and log it on rank zero."""
        self._elapsed_time[stage] = time.monotonic() - self._start_time[stage]
        if trainer.is_global_zero:
            self._log_epoch_info(stage)

    # --- train -------------------------------------------------------------

    @override
    def on_train_epoch_start(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._begin_stage("train")

    @override
    def on_train_batch_end(
        self,
        trainer: "Trainer",
        pl_module: "LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        self._count_batch("train")

    @override
    def on_train_epoch_end(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._finish_stage("train", trainer)

    # --- validation --------------------------------------------------------

    @override
    def on_validation_epoch_start(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._begin_stage("val")

    @override
    def on_validation_batch_end(
        self,
        trainer: "Trainer",
        pl_module: "LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._count_batch("val")

    @override
    def on_validation_epoch_end(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._finish_stage("val", trainer)

    # --- test --------------------------------------------------------------

    @override
    def on_test_epoch_start(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._begin_stage("test")

    @override
    def on_test_batch_end(
        self,
        trainer: "Trainer",
        pl_module: "LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._count_batch("test")

    @override
    def on_test_epoch_end(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._finish_stage("test", trainer)

    # --- predict -----------------------------------------------------------

    @override
    def on_predict_epoch_start(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._begin_stage("predict")

    @override
    def on_predict_batch_end(
        self,
        trainer: "Trainer",
        pl_module: "LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        self._count_batch("predict")

    @override
    def on_predict_epoch_end(
        self, trainer: "Trainer", pl_module: "LightningModule"
    ) -> None:
        self._finish_stage("predict", trainer)

    # --- reporting and checkpoint state --------------------------------------

    def _log_epoch_info(self, stage: str) -> None:
        """Log elapsed time and batch count of the last ``stage`` epoch."""
        elapsed_time = self._elapsed_time.get(stage)
        if elapsed_time is None:
            # No epoch of this stage has finished yet.
            return

        total_batches = self._total_batches[stage]
        # NOTE(review): emitted at CRITICAL level -- presumably so the summary
        # is visible under any logging configuration; confirm this is intended.
        log.critical(
            f"Epoch {stage.capitalize()} Summary: Elapsed Time: {elapsed_time:.2f} seconds | "
            f"Total Batches: {total_batches}"
        )

    @override
    def state_dict(self) -> dict[str, Any]:
        """Return elapsed times and batch counts; start times are not saved."""
        state: dict[str, Any] = {}
        state["elapsed_time"] = self._elapsed_time
        state["total_batches"] = self._total_batches
        return state

    @override
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore what `state_dict` produced."""
        self._elapsed_time = state_dict["elapsed_time"]
        self._total_batches = state_dict["total_batches"]
150
+
151
+
152
class EpochTimerConfig(CallbackConfigBase):
    name: Literal["epoch_timer"] = "epoch_timer"

    @override
    def construct_callbacks(self, root_config):
        # The timer takes no options, so `root_config` is not consulted.
        timer = EpochTimer()
        yield timer
@@ -0,0 +1,103 @@
1
+ from logging import getLogger
2
+ from typing import Literal, Protocol, cast, runtime_checkable
3
+
4
+ import torch.nn as nn
5
+ from lightning.pytorch import LightningModule, Trainer
6
+ from lightning.pytorch.callbacks.callback import Callback
7
+ from lightning.pytorch.loggers import WandbLogger
8
+ from typing_extensions import override
9
+
10
+ from .base import CallbackConfigBase
11
+
12
+ log = getLogger(__name__)
13
+
14
+
15
+ @runtime_checkable
16
+ class _HasWandbLogModuleProtocol(Protocol):
17
+ def wandb_log_module(self) -> nn.Module | None: ...
18
+
19
+
20
class WandbWatchCallback(Callback):
    """Registers the model with the wandb logger via ``WandbLogger.watch``.

    The registration is attempted at the start of fit, validation, test and
    predict. A ``_model_watched`` flag set on the module makes sure the model
    is only registered once.
    """

    def __init__(self, config: "WandbWatchConfig"):
        super().__init__()

        self.config = config

    @override
    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._on_start(trainer, pl_module)

    @override
    def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._on_start(trainer, pl_module)

    @override
    def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._on_start(trainer, pl_module)

    @override
    def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._on_start(trainer, pl_module)

    def _on_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Call ``watch`` on the first `WandbLogger`, unless disabled or done."""
        # The config's truthiness is its `enabled` flag.
        if not self.config:
            return

        # If we're in fast_dev_run, don't watch the model
        if getattr(trainer, "fast_dev_run", False):
            return

        # Pick the first wandb logger attached to the trainer, if any.
        wandb_logger = None
        for candidate in trainer.loggers:
            if isinstance(candidate, WandbLogger):
                wandb_logger = candidate
                break

        if wandb_logger is None:
            log.warning("Could not find wandb logger or module to log")
            return

        # Another hook already registered this module.
        if getattr(pl_module, "_model_watched", False):
            return

        # Get which module to log: the module's own choice when it offers one,
        # otherwise the LightningModule itself.
        module = None
        if isinstance(pl_module, _HasWandbLogModuleProtocol):
            module = pl_module.wandb_log_module()
        if module is None:
            module = cast(nn.Module, pl_module)

        wandb_logger.watch(
            module,
            log=cast(str, self.config.log),
            log_freq=self.config.log_freq,
            log_graph=self.config.log_graph,
        )
        setattr(pl_module, "_model_watched", True)
81
+
82
+
83
class WandbWatchConfig(CallbackConfigBase):
    """Configuration for `WandbWatchCallback`.

    An instance is truthy exactly when `enabled` is true; the callback relies
    on this to decide whether to do anything.
    """

    # The tag used to read "finite_checks", a copy-paste slip that made this
    # config indistinguishable by name from the finite-checks callback config.
    name: Literal["wandb_watch"] = "wandb_watch"

    enabled: bool = True
    """Enable watching the model for wandb."""

    log: str | None = None
    """Log type for wandb."""

    log_graph: bool = True
    """Whether to log the graph for wandb."""

    log_freq: int = 100
    """Log frequency for wandb."""

    def __bool__(self):
        return self.enabled

    @override
    def construct_callbacks(self, root_config):
        """Yield one `WandbWatchCallback` bound to this config.

        The callback is always constructed; it checks the config's
        truthiness itself when its hooks run.
        """
        yield WandbWatchCallback(self)
nshtrainer/config.py ADDED
@@ -0,0 +1,289 @@
1
+ from collections.abc import Mapping, MutableMapping
2
+ from typing import TYPE_CHECKING, Any, ClassVar
3
+
4
+ from pydantic import BaseModel, ConfigDict
5
+ from pydantic import Field as Field
6
+ from pydantic import PrivateAttr as PrivateAttr
7
+ from typing_extensions import deprecated, override
8
+
9
+ from ._config.missing import MISSING, validate_no_missing_values
10
+ from ._config.missing import AllowMissing as AllowMissing
11
+ from ._config.missing import MissingField as MissingField
12
+
13
# Extra base class for `TypedConfig`: a real `MutableMapping[str, Any]` at
# runtime, but plain `object` for static type checkers.
if TYPE_CHECKING:
    _MutableMappingBase = object
else:
    _MutableMappingBase = MutableMapping[str, Any]


# Unique marker passed as the `model_post_init` context when a draft config is
# being constructed; it is compared by identity.
_DraftConfigContextSentinel = object()
19
+
20
+
21
class TypedConfig(BaseModel, _MutableMappingBase):
    """Strictly validated Pydantic base class for configs.

    On top of `BaseModel` this adds:

    - a draft workflow: `draft()` builds an unvalidated instance and
      `finalize()` turns it into a fully validated one;
    - a repr that (by default) only shows fields that differ from their
      defaults;
    - at runtime, a `MutableMapping[str, Any]` interface with dotted-key
      access (``config["a.b.c"]``).
    """

    _is_draft_config: bool = PrivateAttr(default=False)
    """
    Whether this config is a draft config or not.

    Draft configs are configs that are not yet fully validated.
    They allow for a nicer API when creating configs, e.g.:

        ```python
        config = MyConfig.draft()

        # Set some values
        config.a = 10
        config.b = "hello"

        # Finalize the config
        config = config.finalize()
        ```
    """

    repr_diff_only: ClassVar[bool] = True
    """
    If `True`, the repr methods will only show values for fields that are different from the default.
    """

    MISSING: ClassVar[Any] = MISSING
    """
    Alias for the `MISSING` constant.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(
        # By default, Pydantic will throw a warning if a field starts with "model_",
        # so we need to disable that warning (because "model_" is a popular prefix for ML).
        protected_namespaces=(),
        validate_assignment=True,
        validate_return=True,
        validate_default=True,
        strict=True,
        revalidate_instances="always",
        arbitrary_types_allowed=True,
        extra="ignore",
        validation_error_cause=True,
        use_attribute_docstrings=True,
    )

    def __draft_pre_init__(self):
        """Called right before a draft config is finalized."""
        pass

    def __post_init__(self):
        """Called after the final config is validated."""
        pass

    @classmethod
    @deprecated("Use `model_validate` instead.")
    def from_dict(cls, model_dict: Mapping[str, Any]):
        return cls.model_validate(model_dict)

    def model_deep_validate(self, strict: bool = True):
        """
        Validate the config and all of its sub-configs.

        The config is dumped to a dict and validated again from scratch, so
        the returned object is a new instance.

        Args:
            strict: Whether to validate the config strictly.

        Returns:
            The newly validated config.

        Raises:
            ValueError: If the validated config is still marked as a draft.
        """
        config_dict = self.model_dump(round_trip=True)
        config = self.model_validate(config_dict, strict=strict)

        # Make sure that this is not a draft config
        if config._is_draft_config:
            raise ValueError("Draft configs are not valid. Call `finalize` first.")

        return config

    @classmethod
    def draft(cls, **kwargs):
        """Create an unvalidated draft config from ``kwargs``."""
        config = cls.model_construct_draft(**kwargs)
        return config

    def finalize(self, strict: bool = True):
        """Turn a draft config into a validated config.

        Raises:
            ValueError: If this config is not a draft.
        """
        # This must be a draft config, otherwise we raise an error
        if not self._is_draft_config:
            raise ValueError("Finalize can only be called on drafts.")

        # First, we call `__draft_pre_init__` to allow the config to modify itself a final time
        self.__draft_pre_init__()

        # Then, we dump the config to a dict and then re-validate it
        return self.model_deep_validate(strict=strict)

    @override
    def model_post_init(self, __context: Any) -> None:
        super().model_post_init(__context)

        # Call the `__post_init__` method if this is not a draft config
        if __context is _DraftConfigContextSentinel:
            return

        self.__post_init__()

        # After `__post_init__` is called, we perform the final round of validation
        self.model_post_init_validate()

    def model_post_init_validate(self):
        """Run the final validation step: no field may still be `MISSING`."""
        validate_no_missing_values(self)

    @classmethod
    def model_construct_draft(cls, _fields_set: set[str] | None = None, **values: Any):
        """
        NOTE: This is a copy of the `model_construct` method from Pydantic's `Model` class,
        with the following changes:
        - The `model_post_init` method is called with the `_DraftConfigContext` context.
        - The `_is_draft_config` attribute is set to `True` in the `values` dict.

        Creates a new instance of the `Model` class with validated data.

        Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data.
        Default values are respected, but no other validation is performed.

        !!! note
            `model_construct()` generally respects the `model_config.extra` setting on the provided model.
            That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__`
            and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored.
            Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in
            an error if extra values are passed, but they will be ignored.

        Args:
            _fields_set: The set of field names accepted for the Model instance.
            values: Trusted or pre-validated data dictionary.

        Returns:
            A new instance of the `Model` class with validated data.
        """

        values["_is_draft_config"] = True

        m = cls.__new__(cls)
        fields_values: dict[str, Any] = {}
        fields_set = set()

        for name, field in cls.model_fields.items():
            if field.alias and field.alias in values:
                fields_values[name] = values.pop(field.alias)
                fields_set.add(name)
            elif name in values:
                fields_values[name] = values.pop(name)
                fields_set.add(name)
            elif not field.is_required():
                fields_values[name] = field.get_default(call_default_factory=True)
        if _fields_set is None:
            _fields_set = fields_set

        _extra: dict[str, Any] | None = None
        if cls.model_config.get("extra") == "allow":
            _extra = {}
            for k, v in values.items():
                _extra[k] = v
        object.__setattr__(m, "__dict__", fields_values)
        object.__setattr__(m, "__pydantic_fields_set__", _fields_set)
        if not cls.__pydantic_root_model__:
            object.__setattr__(m, "__pydantic_extra__", _extra)

        if cls.__pydantic_post_init__:
            m.model_post_init(_DraftConfigContextSentinel)
            # update private attributes with values set
            if (
                hasattr(m, "__pydantic_private__")
                and m.__pydantic_private__ is not None
            ):
                for k, v in values.items():
                    if k in m.__private_attributes__:
                        m.__pydantic_private__[k] = v

        elif not cls.__pydantic_root_model__:
            # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist
            # Since it doesn't, that means that `__pydantic_private__` should be set to None
            object.__setattr__(m, "__pydantic_private__", None)

        return m

    @override
    def __repr_args__(self):
        # If `repr_diff_only` is `True`, we only show the fields that are different from the default.
        if not self.repr_diff_only:
            yield from super().__repr_args__()
            return

        # First, we get the default values for all fields.
        default_values = self.model_construct_draft()

        # Then, we compare the default values with the current values.
        for k, v in super().__repr_args__():
            if k is None:
                yield k, v
                continue

            # If there is no default value or the value is different from the default, we yield it.
            if not hasattr(default_values, k) or getattr(default_values, k) != v:
                yield k, v
                continue

            # Otherwise, we can skip this field.

    # region MutableMapping implementation
    if not TYPE_CHECKING:
        # This is mainly so the config can be used with lightning's hparams
        # transparently and without any issues.

        @property
        def _ll_dict(self):
            # NOTE: this is a fresh snapshot on every access. It is fine for
            # reading, but anything written into it is thrown away.
            return self.model_dump()

        # We need to make sure every config class
        # is a MutableMapping[str, Any] so that it can be used
        # with lightning's hparams.
        @override
        def __getitem__(self, key: str):
            # Key can be of the format "a.b.c"
            # so we need to split it into a list of keys.
            [first_key, *rest_keys] = key.split(".")
            value = self._ll_dict[first_key]

            for part in rest_keys:
                if isinstance(value, Mapping):
                    value = value[part]
                else:
                    value = getattr(value, part)

            return value

        @override
        def __setitem__(self, key: str, value: Any):
            """Assign ``value`` to the (possibly dotted) field path ``key``.

            The assignment is applied to the live model objects, so it is
            subject to the usual assignment validation. Writing into
            `_ll_dict` instead would only modify a temporary `model_dump()`
            copy and the new value would be lost.

            Raises:
                KeyError: If the first component of a dotted key is not an
                    attribute of this config.
            """
            # Key can be of the format "a.b.c"
            # so we need to split it into a list of keys.
            [first_key, *rest_keys] = key.split(".")
            if not rest_keys:
                setattr(self, first_key, value)
                return

            # We need to traverse the keys until we reach the last key
            # and then set the value
            try:
                current_value = getattr(self, first_key)
            except AttributeError as err:
                raise KeyError(first_key) from err

            for part in rest_keys[:-1]:
                # Nested models are mappings too, but indexing them returns
                # dumped copies, so always walk models by attribute.
                if isinstance(current_value, Mapping) and not isinstance(
                    current_value, BaseModel
                ):
                    current_value = current_value[part]
                else:
                    current_value = getattr(current_value, part)

            # Set the value
            last_key = rest_keys[-1]
            if isinstance(current_value, MutableMapping) and not isinstance(
                current_value, BaseModel
            ):
                current_value[last_key] = value
            else:
                setattr(current_value, last_key, value)

        @override
        def __delitem__(self, key: str):
            # This is unsupported for this class
            raise NotImplementedError

        @override
        def __iter__(self):
            return iter(self._ll_dict)

        @override
        def __len__(self):
            return len(self._ll_dict)

    # endregion
@@ -0,0 +1,4 @@
1
+ from . import transform as dataset_transform
2
+ from .balanced_batch_sampler import BalancedBatchSampler as BalancedBatchSampler
3
+
4
+ _ = dataset_transform