nshtrainer-1.3.5-py3-none-any.whl → nshtrainer-1.3.6-py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- nshtrainer/__init__.py +14 -0
- nshtrainer/configs/__init__.py +1 -5
- nshtrainer/configs/trainer/__init__.py +4 -2
- nshtrainer/configs/trainer/_config/__init__.py +4 -2
- nshtrainer/trainer/_config.py +517 -71
- nshtrainer/trainer/trainer.py +1 -0
- {nshtrainer-1.3.5.dist-info → nshtrainer-1.3.6.dist-info}/METADATA +1 -1
- {nshtrainer-1.3.5.dist-info → nshtrainer-1.3.6.dist-info}/RECORD +9 -11
- nshtrainer/_directory.py +0 -72
- nshtrainer/configs/_directory/__init__.py +0 -15
- {nshtrainer-1.3.5.dist-info → nshtrainer-1.3.6.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py
CHANGED
@@ -19,3 +19,17 @@ try:
     from . import configs as configs
 except BaseException:
     pass
+
+try:
+    from importlib.metadata import PackageNotFoundError, version
+except ImportError:
+    # For Python <3.8
+    from importlib_metadata import (  # pyright: ignore[reportMissingImports]
+        PackageNotFoundError,
+        version,
+    )
+
+try:
+    __version__ = version(__name__)
+except PackageNotFoundError:
+    __version__ = "unknown"
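Note: this block gives the package a runtime __version__ resolved from installed distribution metadata, falling back to "unknown" when the distribution cannot be found. A minimal usage sketch (assuming the wheel is installed):

    import nshtrainer
    print(nshtrainer.__version__)  # "1.3.6" for this release; "unknown" if metadata is unavailable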
nshtrainer/configs/__init__.py
CHANGED
@@ -5,7 +5,6 @@ __codegen__ = True
 from nshtrainer import MetricConfig as MetricConfig
 from nshtrainer import TrainerConfig as TrainerConfig
 from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
-from nshtrainer._directory import DirectoryConfig as DirectoryConfig
 from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
 from nshtrainer._hf_hub import (
     HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
@@ -126,9 +125,9 @@ from nshtrainer.trainer._config import (
     CheckpointCallbackConfig as CheckpointCallbackConfig,
 )
 from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
+from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
 from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
 from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
-from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
 from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
 from nshtrainer.trainer.accelerator import (
@@ -227,7 +226,6 @@ from nshtrainer.util.config import EpochsConfig as EpochsConfig
 from nshtrainer.util.config import StepsConfig as StepsConfig
 
 from . import _checkpoint as _checkpoint
-from . import _directory as _directory
 from . import _hf_hub as _hf_hub
 from . import callbacks as callbacks
 from . import loggers as loggers
@@ -338,7 +336,6 @@ __all__ = [
     "RpropConfig",
     "SGDConfig",
     "SLURMEnvironmentPlugin",
-    "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
     "SiLUNonlinearityConfig",
     "SigmoidNonlinearityConfig",
@@ -367,7 +364,6 @@ __all__ = [
     "XLAEnvironmentPlugin",
     "XLAPluginConfig",
     "_checkpoint",
-    "_directory",
     "_hf_hub",
     "accelerator_registry",
     "callback_registry",
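Note: code that imported DirectoryConfig through the codegen'd nshtrainer.configs namespace keeps working after this release; the re-export now resolves to nshtrainer.trainer._config instead of the removed nshtrainer._directory module. A small sketch of the equivalence (assuming no other re-export changes):

    from nshtrainer.configs import DirectoryConfig
    from nshtrainer.trainer._config import DirectoryConfig as _DirectoryConfig

    assert DirectoryConfig is _DirectoryConfig  # same class, new home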
nshtrainer/configs/trainer/__init__.py
CHANGED
@@ -22,6 +22,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -51,7 +54,6 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
 from nshtrainer.trainer._config import (
     RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
 )
-from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
@@ -152,6 +154,7 @@ __all__ = [
     "DebugFlagCallbackConfig",
     "DeepSpeedPluginConfig",
     "DirectoryConfig",
+    "DirectorySetupCallbackConfig",
     "DistributedPredictionWriterConfig",
     "DoublePrecisionPluginConfig",
     "EarlyStoppingCallbackConfig",
@@ -180,7 +183,6 @@ __all__ = [
     "ProfilerConfig",
     "RLPSanityChecksCallbackConfig",
     "SLURMEnvironmentPlugin",
-    "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
     "StrategyConfig",
     "StrategyConfigBase",
nshtrainer/configs/trainer/_config/__init__.py
CHANGED
@@ -18,6 +18,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -48,7 +51,6 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
 from nshtrainer.trainer._config import (
     RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
 )
-from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
@@ -70,6 +72,7 @@ __all__ = [
     "CheckpointSavingConfig",
     "DebugFlagCallbackConfig",
     "DirectoryConfig",
+    "DirectorySetupCallbackConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
     "GradientClippingConfig",
@@ -86,7 +89,6 @@ __all__ = [
     "PluginConfig",
     "ProfilerConfig",
     "RLPSanityChecksCallbackConfig",
-    "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
     "StrategyConfig",
     "TensorboardLoggerConfig",
nshtrainer/trainer/_config.py
CHANGED
@@ -26,7 +26,6 @@ from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.strategies.strategy import Strategy
 from typing_extensions import TypeAliasType, TypedDict, override
 
-from .._directory import DirectoryConfig
 from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
@@ -38,6 +37,7 @@ from ..callbacks import (
 )
 from ..callbacks.base import CallbackConfigBase
 from ..callbacks.debug_flag import DebugFlagCallbackConfig
+from ..callbacks.directory_setup import DirectorySetupCallbackConfig
 from ..callbacks.log_epoch import LogEpochCallbackConfig
 from ..callbacks.lr_monitor import LearningRateMonitorConfig
 from ..callbacks.metric_validation import MetricValidationCallbackConfig
@@ -352,19 +352,74 @@ class LightningTrainerKwargs(TypedDict, total=False):
     """
 
 
-
-
+DEFAULT_LOGDIR_BASENAME = "nshtrainer_logs"
+"""Default base name for the log directory."""
+
+
+class DirectoryConfig(C.Config):
+    project_root: Path | None = None
     """
-
-
-
-    Valid values are: "disable", "warn", "error".
+    Root directory for this project.
+
+    This isn't specific to the current run; it is the parent directory of all runs.
     """
 
+    logdir_basename: str = DEFAULT_LOGDIR_BASENAME
+    """Base name for the log directory."""
+
+    setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
+    """Configuration for the directory setup PyTorch Lightning callback."""
+
+    def resolve_run_root_directory(self, run_id: str) -> Path:
+        if (project_root_dir := self.project_root) is None:
+            project_root_dir = Path.cwd()
+
+        # The default base dir is $CWD/{logdir_basename}/{id}/
+        base_dir = project_root_dir / self.logdir_basename
+        base_dir.mkdir(exist_ok=True)
+
+        # Add a .gitignore file to the {logdir_basename} directory
+        # which will ignore all files except for the .gitignore file itself
+        gitignore_path = base_dir / ".gitignore"
+        if not gitignore_path.exists():
+            gitignore_path.touch()
+            gitignore_path.write_text("*\n")
+
+        base_dir = base_dir / run_id
+        base_dir.mkdir(exist_ok=True)
+
+        return base_dir
+
+    def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
+        # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
+        if (subdir := getattr(self, subdirectory, None)) is not None:
+            assert isinstance(subdir, Path), (
+                f"Expected a Path for {subdirectory}, got {type(subdir)}"
+            )
+            return subdir
+
+        dir = self.resolve_run_root_directory(run_id)
+        dir = dir / subdirectory
+        dir.mkdir(exist_ok=True)
+        return dir
+
+    def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
+        if (log_dir := logger.log_dir) is not None:
+            return log_dir
+
+        # Save to {logdir_basename}/{id}/log/{logger name}
+        log_dir = self.resolve_subdirectory(run_id, "log")
+        log_dir = log_dir / logger.resolve_logger_dirname()
+        # ^ NOTE: Logger must have a `name` attribute, as this is
+        #   the discriminator for the logger registry
+        log_dir.mkdir(exist_ok=True)
+
+        return log_dir
+
 
 class TrainerConfig(C.Config):
     # region Active Run Configuration
-    id: str
+    id: Annotated[str, C.AllowMissing()] = C.MISSING
     """ID of the run."""
     name: list[str] = []
     """Run name in parts. Full name is constructed by joining the parts with spaces."""
@@ -393,39 +448,6 @@ class TrainerConfig(C.Config):
 
     directory: DirectoryConfig = DirectoryConfig()
     """Directory configuration options."""
-
-    _rng: ClassVar[np.random.Generator | None] = None
-
-    @classmethod
-    def generate_id(cls, *, length: int = 8) -> str:
-        """
-        Generate a random ID of specified length.
-
-        """
-        if (rng := cls._rng) is None:
-            rng = np.random.default_rng()
-
-        alphabet = list(string.ascii_lowercase + string.digits)
-
-        id = "".join(rng.choice(alphabet) for _ in range(length))
-        return id
-
-    @classmethod
-    def set_seed(cls, seed: int | None = None) -> None:
-        """
-        Set the seed for the random number generator.
-
-        Args:
-            seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
-
-        Returns:
-            None
-        """
-        if seed is None:
-            seed = int(time.time() * 1000)
-        log.critical(f"Seeding {cls.__name__} with seed {seed}")
-        cls._rng = np.random.default_rng(seed)
-
     # endregion
 
     primary_metric: MetricConfig | None = None
@@ -755,40 +777,40 @@ class TrainerConfig(C.Config):
             None,
         )
 
-
-
-
-
-        yield self.lr_monitor
-        yield from (
-            logger_config
-            for logger_config in self.enabled_loggers()
-            if logger_config is not None
-            and isinstance(logger_config, CallbackConfigBase)
-        )
-        yield self.log_epoch
-        yield self.log_norms
-        yield self.hf_hub
-        yield self.shared_parameters
-        yield self.reduce_lr_on_plateau_sanity_checking
-        yield self.auto_set_debug_flag
-        yield self.auto_validate_metrics
-        yield from self.callbacks
+    # region Helper Methods
+    def id_(self, value: str):
+        """
+        Set the id for the trainer configuration in-place.
 
-
-
-
-
+        Parameters
+        ----------
+        value : str
+            The id value to set
 
-
-
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.id = value
+        return self
 
-    def
-
-
-
+    def with_id(self, value: str):
+        """
+        Create a copy of the current configuration with an updated id.
+
+        Parameters
+        ----------
+        value : str
+            The id value to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated id
+        """
+        return copy.deepcopy(self).id_(value)
 
-    # region Helper Methods
     def fast_dev_run_(self, value: int | bool = True, /):
         """
         Enables fast_dev_run mode for the trainer.
@@ -831,6 +853,349 @@ class TrainerConfig(C.Config):
         """
        return copy.deepcopy(self).project_root_(project_root)
 
+    def name_(self, *parts: str):
+        """
+        Set the name for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *parts : str
+            The parts of the name to set. Will be joined with spaces.
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.name = list(parts)
+        return self
+
+    def with_name(self, *parts: str):
+        """
+        Create a copy of the current configuration with an updated name.
+
+        Parameters
+        ----------
+        *parts : str
+            The parts of the name to set. Will be joined with spaces.
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated name
+        """
+        return copy.deepcopy(self).name_(*parts)
+
+    def project_(self, project: str | None):
+        """
+        Set the project name for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        project : str | None
+            The project name to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.project = project
+        return self
+
+    def with_project(self, project: str | None):
+        """
+        Create a copy of the current configuration with an updated project name.
+
+        Parameters
+        ----------
+        project : str | None
+            The project name to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated project name
+        """
+        return copy.deepcopy(self).project_(project)
+
+    def tags_(self, *tags: str):
+        """
+        Set the tags for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *tags : str
+            The tags to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.tags = list(tags)
+        return self
+
+    def with_tags(self, *tags: str):
+        """
+        Create a copy of the current configuration with updated tags.
+
+        Parameters
+        ----------
+        *tags : str
+            The tags to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated tags
+        """
+        return copy.deepcopy(self).tags_(*tags)
+
+    def add_tags_(self, *tags: str):
+        """
+        Add tags to the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *tags : str
+            The tags to add
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.tags.extend(tags)
+        return self
+
+    def with_added_tags(self, *tags: str):
+        """
+        Create a copy of the current configuration with additional tags.
+
+        Parameters
+        ----------
+        *tags : str
+            The tags to add
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the additional tags
+        """
+        return copy.deepcopy(self).add_tags_(*tags)
+
+    def notes_(self, *notes: str):
+        """
+        Set the notes for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *notes : str
+            The notes to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.notes = list(notes)
+        return self
+
+    def with_notes(self, *notes: str):
+        """
+        Create a copy of the current configuration with updated notes.
+
+        Parameters
+        ----------
+        *notes : str
+            The notes to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated notes
+        """
+        return copy.deepcopy(self).notes_(*notes)
+
+    def add_notes_(self, *notes: str):
+        """
+        Add notes to the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *notes : str
+            The notes to add
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.notes.extend(notes)
+        return self
+
+    def with_added_notes(self, *notes: str):
+        """
+        Create a copy of the current configuration with additional notes.
+
+        Parameters
+        ----------
+        *notes : str
+            The notes to add
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the additional notes
+        """
+        return copy.deepcopy(self).add_notes_(*notes)
+
+    def meta_(self, meta: dict[str, Any] | None = None, /, **kwargs: Any):
+        """
+        Update the `meta` dictionary in-place with the provided key-value pairs.
+
+        This method allows updating the meta information associated with the trainer
+        configuration by either passing a dictionary or keyword arguments.
+
+        Parameters
+        ----------
+        meta : dict[str, Any] | None, optional
+            A dictionary containing meta information to be added, by default None
+        **kwargs : Any
+            Additional key-value pairs to be added to the meta dictionary
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        if meta is not None:
+            self.meta.update(meta)
+        self.meta.update(kwargs)
+        return self
+
+    def with_meta(self, meta: dict[str, Any] | None = None, /, **kwargs: Any):
+        """
+        Create a copy of the current configuration with updated meta information.
+
+        This method is similar to `meta_`, but it returns a new instance of the configuration
+        with the updated meta information instead of modifying the current instance.
+
+        Parameters
+        ----------
+        meta : dict[str, Any] | None, optional
+            A dictionary containing meta information to be added, by default None
+        **kwargs : Any
+            Additional key-value pairs to be added to the meta dictionary
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with updated meta information
+        """
+
+        return self.model_copy(deep=True).meta_(meta, **kwargs)
+
+    def debug_(self, value: bool = True):
+        """
+        Set the debug flag for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        value : bool, optional
+            The debug flag value to set, by default True
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.debug = value
+        return self
+
+    def with_debug(self, value: bool = True):
+        """
+        Create a copy of the current configuration with an updated debug flag.
+
+        Parameters
+        ----------
+        value : bool, optional
+            The debug flag value to set, by default True
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated debug flag
+        """
+        return copy.deepcopy(self).debug_(value)
+
+    def ckpt_path_(self, path: Literal["none"] | str | Path | None):
+        """
+        Set the checkpoint path for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        path : Literal["none"] | str | Path | None
+            The checkpoint path to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.ckpt_path = path
+        return self
+
+    def with_ckpt_path(self, path: Literal["none"] | str | Path | None):
+        """
+        Create a copy of the current configuration with an updated checkpoint path.
+
+        Parameters
+        ----------
+        path : Literal["none"] | str | Path | None
+            The checkpoint path to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated checkpoint path
+        """
+        return copy.deepcopy(self).ckpt_path_(path)
+
+    def barebones_(self, value: bool = True):
+        """
+        Set the barebones flag for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        value : bool, optional
+            The barebones flag value to set, by default True
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.barebones = value
+        return self
+
+    def with_barebones(self, value: bool = True):
+        """
+        Create a copy of the current configuration with an updated barebones flag.
+
+        Parameters
+        ----------
+        value : bool, optional
+            The barebones flag value to set, by default True
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated barebones flag
+        """
+        return copy.deepcopy(self).barebones_(value)
+
     def reset_run(
         self,
         *,
@@ -873,3 +1238,84 @@ class TrainerConfig(C.Config):
         return config
 
     # endregion
+
+    # region Random ID Generation
+    _rng: ClassVar[np.random.Generator | None] = None
+
+    @classmethod
+    def generate_id(cls, *, length: int = 8) -> str:
+        """
+        Generate a random ID of specified length.
+
+        """
+        if (rng := cls._rng) is None:
+            rng = np.random.default_rng()
+
+        alphabet = list(string.ascii_lowercase + string.digits)
+
+        id = "".join(rng.choice(alphabet) for _ in range(length))
+        return id
+
+    @classmethod
+    def set_seed(cls, seed: int | None = None) -> None:
+        """
+        Set the seed for the random number generator.
+
+        Args:
+            seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
+
+        Returns:
+            None
+        """
+        if seed is None:
+            seed = int(time.time() * 1000)
+        log.critical(f"Seeding {cls.__name__} with seed {seed}")
+        cls._rng = np.random.default_rng(seed)
+
+    # endregion
+
+    # region Internal Methods
+    def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
+        yield self.directory.setup_callback
+        yield self.early_stopping
+        yield self.checkpoint_saving
+        yield self.lr_monitor
+        yield from (
+            logger_config
+            for logger_config in self.enabled_loggers()
+            if logger_config is not None
+            and isinstance(logger_config, CallbackConfigBase)
+        )
+        yield self.log_epoch
+        yield self.log_norms
+        yield self.hf_hub
+        yield self.shared_parameters
+        yield self.reduce_lr_on_plateau_sanity_checking
+        yield self.auto_set_debug_flag
+        yield self.auto_validate_metrics
+        yield from self.callbacks
+
+    def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
+        # Disable all loggers if barebones mode is enabled
+        if self.barebones:
+            return
+
+        yield from self.enabled_loggers()
+        yield self.actsave_logger
+
+    def _nshtrainer_validate_before_run(self):
+        # shared_parameters is not supported under barebones mode
+        if self.barebones and self.shared_parameters:
+            raise ValueError("shared_parameters is not supported under barebones mode")
+
+    def _nshtrainer_set_id_if_missing(self):
+        """
+        Set the ID for the configuration object if it is missing.
+        """
+        if self.id is C.MISSING:
+            self.id = self.generate_id()
+            log.info(f"TrainerConfig's run ID is missing, setting to {self.id}.")
+        else:
+            log.debug(f"TrainerConfig's run ID is already set to {self.id}.")
+
+    # endregion
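Note: the added helpers come in pairs, a trailing-underscore in-place setter and a with_* copy-on-write variant, so run metadata can be built by chaining. A minimal sketch (assuming the remaining TrainerConfig fields have workable defaults in your setup):

    from nshtrainer import TrainerConfig

    config = (
        TrainerConfig()                    # id may now be left unset (C.MISSING)
        .with_name("baseline", "fp16")     # name parts are joined with spaces
        .with_project("demo-project")      # hypothetical project name
        .with_tags("ablation")
        .with_debug(True)
    )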
nshtrainer/trainer/trainer.py
CHANGED
@@ -316,6 +316,7 @@ class Trainer(LightningTrainer):
                 f"Trainer hparams must either be an instance of {hparams_cls} or a mapping. "
                 f"Got {type(hparams)=} instead."
             )
+        hparams._nshtrainer_set_id_if_missing()
         hparams = hparams.model_deep_validate()
         hparams._nshtrainer_validate_before_run()
 
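Note: with this one added line the Trainer assigns a generated run ID (via TrainerConfig.generate_id()) before deep-validating the hparams, so a config whose id was left as C.MISSING no longer has to be filled in by the caller. A hedged sketch (assuming the usual Trainer(config, ...) construction in this package):

    config = TrainerConfig().with_name("demo")  # id left as C.MISSING
    trainer = Trainer(config)                   # id is generated here, before validation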
{nshtrainer-1.3.5.dist-info → nshtrainer-1.3.6.dist-info}/RECORD
CHANGED
@@ -1,9 +1,8 @@
 nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
-nshtrainer/__init__.py,sha256=
+nshtrainer/__init__.py,sha256=RI_2B_IUWa10B6H5TAuWtE5FWX1X4ue-J4dTDaF2-lQ,1035
 nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
 nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2uLzQE,5597
 nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
-nshtrainer/_directory.py,sha256=RAG8e0y3VZwGIyy_D-GXgDMK5OvitQU6qEWxHTpWEeY,2490
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 nshtrainer/_hf_hub.py,sha256=kfN0wDxK5JWKKGZnX_706i0KXGhaS19p581LDTPxlRE,13996
 nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
@@ -33,10 +32,9 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
 nshtrainer/callbacks/wandb_upload_code.py,sha256=4X-mpiX5ghj9vnEreK2i8Xyvimqt0K-PNWA2HtT-B6I,1940
 nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
 nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
-nshtrainer/configs/__init__.py,sha256
+nshtrainer/configs/__init__.py,sha256=-yJ5Uk9VkANqfk-QnX2aynL0jSf7cJQuQNzT1GAE1x8,15684
 nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
-nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
 nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
 nshtrainer/configs/callbacks/__init__.py,sha256=tP9urR73NIanyxpbi4EERsxOnGNiptbQpmsj-v53a38,4774
 nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JvjSZtEoA28FC4u-QT3skQzBDVbN9eq07rn4u2ydW-E,377
@@ -85,8 +83,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
 nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
-nshtrainer/configs/trainer/__init__.py,sha256=
-nshtrainer/configs/trainer/_config/__init__.py,sha256=
+nshtrainer/configs/trainer/__init__.py,sha256=DM2PlB4WRDZ_dqEeW91LbKRFa4sIF_pETU0T9GYJ5-g,8073
+nshtrainer/configs/trainer/_config/__init__.py,sha256=z5UpuXktBanLOYNkkbgbbHE06iQtcSuAKTpnx2TLmCo,3850
 nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
 nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
 nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=slW5z1FZw2qICXO9l9DnLIDB1Yl7KOcxPEZkyYIHrp4,276
@@ -135,7 +133,7 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
 nshtrainer/trainer/__init__.py,sha256=jRaHdaFK8wxNrN1bleT9cf29iZahL_-XkWo5TWz2CmA,550
-nshtrainer/trainer/_config.py,sha256=
+nshtrainer/trainer/_config.py,sha256=x3YjP_0IykqSRh8YIilCxq3nPt_fZXoVcxfR13ulmV0,45578
 nshtrainer/trainer/_distributed_prediction_result.py,sha256=bQw8Z6PT694UUf-zQPkech6CxyUSy8bAIexfSfPej0U,2507
 nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
@@ -148,7 +146,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLc
 nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
 nshtrainer/trainer/signal_connector.py,sha256=ZgbSkbthoe8MYN6rBoFf-7UDpQtc9fs9pG_FNvTYSfs,10962
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=
+nshtrainer/trainer/trainer.py,sha256=iQWu0KfwLY-1q9EEsg0xPlyUN1fsJ9iXfSQbPmiMlac,24177
 nshtrainer/util/_environment_info.py,sha256=j-wyEHKirsu3rIXTtqC2kLmIIkRe6obWjxPVWaqg2ow,24887
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
 nshtrainer/util/code_upload.py,sha256=CpbZEBbA8EcBElUVoCPbP5zdwtNzJhS20RLaOB-q-2k,1257
@@ -161,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.3.
-nshtrainer-1.3.
-nshtrainer-1.3.
+nshtrainer-1.3.6.dist-info/METADATA,sha256=KidLM7J5P7mALTfYQsveSDjAOJZ-Gcq4_e1-Xgrms68,979
+nshtrainer-1.3.6.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+nshtrainer-1.3.6.dist-info/RECORD,,
nshtrainer/_directory.py
DELETED
@@ -1,72 +0,0 @@
-from __future__ import annotations
-
-import logging
-from pathlib import Path
-
-import nshconfig as C
-
-from .callbacks.directory_setup import DirectorySetupCallbackConfig
-from .loggers import LoggerConfig
-
-log = logging.getLogger(__name__)
-
-
-class DirectoryConfig(C.Config):
-    project_root: Path | None = None
-    """
-    Root directory for this project.
-
-    This isn't specific to the run; it is the parent directory of all runs.
-    """
-
-    logdir_basename: str = "nshtrainer"
-    """Base name for the log directory."""
-
-    setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
-    """Configuration for the directory setup PyTorch Lightning callback."""
-
-    def resolve_run_root_directory(self, run_id: str) -> Path:
-        if (project_root_dir := self.project_root) is None:
-            project_root_dir = Path.cwd()
-
-        # The default base dir is $CWD/{logdir_basename}/{id}/
-        base_dir = project_root_dir / self.logdir_basename
-        base_dir.mkdir(exist_ok=True)
-
-        # Add a .gitignore file to the {logdir_basename} directory
-        # which will ignore all files except for the .gitignore file itself
-        gitignore_path = base_dir / ".gitignore"
-        if not gitignore_path.exists():
-            gitignore_path.touch()
-            gitignore_path.write_text("*\n")
-
-        base_dir = base_dir / run_id
-        base_dir.mkdir(exist_ok=True)
-
-        return base_dir
-
-    def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
-        # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
-        if (subdir := getattr(self, subdirectory, None)) is not None:
-            assert isinstance(subdir, Path), (
-                f"Expected a Path for {subdirectory}, got {type(subdir)}"
-            )
-            return subdir
-
-        dir = self.resolve_run_root_directory(run_id)
-        dir = dir / subdirectory
-        dir.mkdir(exist_ok=True)
-        return dir
-
-    def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
-        if (log_dir := logger.log_dir) is not None:
-            return log_dir
-
-        # Save to {logdir_basename}/{id}/log/{logger name}
-        log_dir = self.resolve_subdirectory(run_id, "log")
-        log_dir = log_dir / logger.resolve_logger_dirname()
-        # ^ NOTE: Logger must have a `name` attribute, as this is
-        #   the discriminator for the logger registry
-        log_dir.mkdir(exist_ok=True)
-
-        return log_dir
nshtrainer/configs/_directory/__init__.py
DELETED
@@ -1,15 +0,0 @@
-from __future__ import annotations
-
-__codegen__ = True
-
-from nshtrainer._directory import DirectoryConfig as DirectoryConfig
-from nshtrainer._directory import (
-    DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
-)
-from nshtrainer._directory import LoggerConfig as LoggerConfig
-
-__all__ = [
-    "DirectoryConfig",
-    "DirectorySetupCallbackConfig",
-    "LoggerConfig",
-]
{nshtrainer-1.3.5.dist-info → nshtrainer-1.3.6.dist-info}/WHEEL
File without changes