nshtrainer-1.3.4-py3-none-any.whl → nshtrainer-1.3.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +14 -0
- nshtrainer/_hf_hub.py +3 -11
- nshtrainer/callbacks/wandb_upload_code.py +5 -17
- nshtrainer/configs/__init__.py +1 -5
- nshtrainer/configs/trainer/__init__.py +4 -2
- nshtrainer/configs/trainer/_config/__init__.py +4 -2
- nshtrainer/trainer/_config.py +517 -71
- nshtrainer/trainer/signal_connector.py +12 -7
- nshtrainer/trainer/trainer.py +1 -0
- nshtrainer/util/_environment_info.py +14 -6
- nshtrainer/util/code_upload.py +40 -0
- {nshtrainer-1.3.4.dist-info → nshtrainer-1.3.6.dist-info}/METADATA +2 -2
- {nshtrainer-1.3.4.dist-info → nshtrainer-1.3.6.dist-info}/RECORD +14 -15
- nshtrainer/_directory.py +0 -72
- nshtrainer/configs/_directory/__init__.py +0 -15
- {nshtrainer-1.3.4.dist-info → nshtrainer-1.3.6.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py
CHANGED
@@ -19,3 +19,17 @@ try:
     from . import configs as configs
 except BaseException:
     pass
+
+try:
+    from importlib.metadata import PackageNotFoundError, version
+except ImportError:
+    # For Python <3.8
+    from importlib_metadata import (  # pyright: ignore[reportMissingImports]
+        PackageNotFoundError,
+        version,
+    )
+
+try:
+    __version__ = version(__name__)
+except PackageNotFoundError:
+    __version__ = "unknown"
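The added block is a standard importlib.metadata lookup, so the installed package now reports its own version. A small usage sketch, not part of the diff:

    import nshtrainer

    # Resolved from the installed distribution's metadata; falls back to
    # "unknown" when no distribution metadata can be found.
    print(nshtrainer.__version__)  # e.g. "1.3.6"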
nshtrainer/_hf_hub.py
CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import contextlib
 import logging
-import os
 import re
 from dataclasses import dataclass
 from functools import cached_property
@@ -10,7 +9,6 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
 
 import nshconfig as C
-from nshrunner._env import SNAPSHOT_DIR
 from typing_extensions import assert_never, override
 
 from ._callback import NTCallbackBase
@@ -19,6 +17,7 @@ from .callbacks.base import (
     CallbackMetadataConfig,
     callback_registry,
 )
+from .util.code_upload import get_code_dir
 
 if TYPE_CHECKING:
     from huggingface_hub import HfApi  # noqa: F401
@@ -319,20 +318,13 @@ class HFHubCallback(NTCallbackBase):
     def _save_code(self):
         # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
         # then upload all contents within the snapshot directory to the repository.
-        if
+        if (snapshot_dir := get_code_dir()) is None:
             log.debug("No snapshot directory found. Skipping upload.")
             return
 
         with self._with_error_handling("save code"):
-            snapshot_dir = Path(snapshot_dir)
-            if not snapshot_dir.exists() or not snapshot_dir.is_dir():
-                log.warning(
-                    f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
-                )
-                return
-
             self.api.upload_folder(
-                folder_path=str(snapshot_dir),
+                folder_path=str(snapshot_dir.absolute()),
                 repo_id=self.repo_id,
                 repo_type="model",
                 path_in_repo="code",  # Prefix with "code" folder
nshtrainer/callbacks/wandb_upload_code.py
CHANGED
@@ -1,16 +1,14 @@
 from __future__ import annotations
 
 import logging
-import os
-from pathlib import Path
 from typing import Literal, cast
 
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.loggers import WandbLogger
-from nshrunner._env import SNAPSHOT_DIR
 from typing_extensions import final, override
 
+from ..util.code_upload import get_code_dir
 from .base import CallbackConfigBase, callback_registry
 
 log = logging.getLogger(__name__)
@@ -62,22 +60,12 @@ class WandbUploadCodeCallback(Callback):
             log.warning("Wandb logger not found. Skipping code upload.")
             return
 
-
-
-        run = cast(Run, logger.experiment)
-
-        # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
-        # then upload all contents within the snapshot directory to the repository.
-        if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
-            log.debug("No snapshot directory found. Skipping upload.")
+        if (snapshot_dir := get_code_dir()) is None:
+            log.info("No nshrunner snapshot found. Skipping code upload.")
             return
 
-
-        if not snapshot_dir.exists() or not snapshot_dir.is_dir():
-            log.warning(
-                f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
-            )
-            return
+        from wandb.wandb_run import Run
 
+        run = cast(Run, logger.experiment)
         log.info(f"Uploading code from snapshot directory '{snapshot_dir}'")
         run.log_code(str(snapshot_dir.absolute()))
nshtrainer/configs/__init__.py
CHANGED
@@ -5,7 +5,6 @@ __codegen__ = True
 from nshtrainer import MetricConfig as MetricConfig
 from nshtrainer import TrainerConfig as TrainerConfig
 from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
-from nshtrainer._directory import DirectoryConfig as DirectoryConfig
 from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
 from nshtrainer._hf_hub import (
     HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
@@ -126,9 +125,9 @@ from nshtrainer.trainer._config import (
     CheckpointCallbackConfig as CheckpointCallbackConfig,
 )
 from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
+from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
 from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
 from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
-from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
 from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
 from nshtrainer.trainer.accelerator import (
@@ -227,7 +226,6 @@ from nshtrainer.util.config import EpochsConfig as EpochsConfig
 from nshtrainer.util.config import StepsConfig as StepsConfig
 
 from . import _checkpoint as _checkpoint
-from . import _directory as _directory
 from . import _hf_hub as _hf_hub
 from . import callbacks as callbacks
 from . import loggers as loggers
@@ -338,7 +336,6 @@ __all__ = [
     "RpropConfig",
     "SGDConfig",
     "SLURMEnvironmentPlugin",
-    "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
     "SiLUNonlinearityConfig",
     "SigmoidNonlinearityConfig",
@@ -367,7 +364,6 @@ __all__ = [
     "XLAEnvironmentPlugin",
     "XLAPluginConfig",
     "_checkpoint",
-    "_directory",
     "_hf_hub",
     "accelerator_registry",
     "callback_registry",
nshtrainer/configs/trainer/__init__.py
CHANGED
@@ -22,6 +22,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -51,7 +54,6 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
 from nshtrainer.trainer._config import (
     RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
 )
-from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
@@ -152,6 +154,7 @@ __all__ = [
     "DebugFlagCallbackConfig",
     "DeepSpeedPluginConfig",
     "DirectoryConfig",
+    "DirectorySetupCallbackConfig",
     "DistributedPredictionWriterConfig",
     "DoublePrecisionPluginConfig",
     "EarlyStoppingCallbackConfig",
@@ -180,7 +183,6 @@ __all__ = [
     "ProfilerConfig",
     "RLPSanityChecksCallbackConfig",
     "SLURMEnvironmentPlugin",
-    "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
     "StrategyConfig",
     "StrategyConfigBase",
nshtrainer/configs/trainer/_config/__init__.py
CHANGED
@@ -18,6 +18,9 @@ from nshtrainer.trainer._config import (
     DebugFlagCallbackConfig as DebugFlagCallbackConfig,
 )
 from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
+from nshtrainer.trainer._config import (
+    DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
+)
 from nshtrainer.trainer._config import (
     EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
 )
@@ -48,7 +51,6 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
 from nshtrainer.trainer._config import (
     RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
 )
-from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
 from nshtrainer.trainer._config import (
     SharedParametersCallbackConfig as SharedParametersCallbackConfig,
 )
@@ -70,6 +72,7 @@ __all__ = [
     "CheckpointSavingConfig",
     "DebugFlagCallbackConfig",
     "DirectoryConfig",
+    "DirectorySetupCallbackConfig",
     "EarlyStoppingCallbackConfig",
     "EnvironmentConfig",
     "GradientClippingConfig",
@@ -86,7 +89,6 @@ __all__ = [
     "PluginConfig",
     "ProfilerConfig",
     "RLPSanityChecksCallbackConfig",
-    "SanityCheckingConfig",
     "SharedParametersCallbackConfig",
     "StrategyConfig",
     "TensorboardLoggerConfig",
nshtrainer/trainer/_config.py
CHANGED
@@ -26,7 +26,6 @@ from lightning.pytorch.profilers import Profiler
 from lightning.pytorch.strategies.strategy import Strategy
 from typing_extensions import TypeAliasType, TypedDict, override
 
-from .._directory import DirectoryConfig
 from .._hf_hub import HuggingFaceHubConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
@@ -38,6 +37,7 @@ from ..callbacks import (
 )
 from ..callbacks.base import CallbackConfigBase
 from ..callbacks.debug_flag import DebugFlagCallbackConfig
+from ..callbacks.directory_setup import DirectorySetupCallbackConfig
 from ..callbacks.log_epoch import LogEpochCallbackConfig
 from ..callbacks.lr_monitor import LearningRateMonitorConfig
 from ..callbacks.metric_validation import MetricValidationCallbackConfig
@@ -352,19 +352,74 @@ class LightningTrainerKwargs(TypedDict, total=False):
     """
 
 
-
-
+DEFAULT_LOGDIR_BASENAME = "nshtrainer_logs"
+"""Default base name for the log directory."""
+
+
+class DirectoryConfig(C.Config):
+    project_root: Path | None = None
     """
-
-
-
-    Valid values are: "disable", "warn", "error".
+    Root directory for this project.
+
+    This isn't specific to the current run; it is the parent directory of all runs.
     """
 
+    logdir_basename: str = DEFAULT_LOGDIR_BASENAME
+    """Base name for the log directory."""
+
+    setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
+    """Configuration for the directory setup PyTorch Lightning callback."""
+
+    def resolve_run_root_directory(self, run_id: str) -> Path:
+        if (project_root_dir := self.project_root) is None:
+            project_root_dir = Path.cwd()
+
+        # The default base dir is $CWD/{logdir_basename}/{id}/
+        base_dir = project_root_dir / self.logdir_basename
+        base_dir.mkdir(exist_ok=True)
+
+        # Add a .gitignore file to the {logdir_basename} directory
+        # which will ignore all files except for the .gitignore file itself
+        gitignore_path = base_dir / ".gitignore"
+        if not gitignore_path.exists():
+            gitignore_path.touch()
+            gitignore_path.write_text("*\n")
+
+        base_dir = base_dir / run_id
+        base_dir.mkdir(exist_ok=True)
+
+        return base_dir
+
+    def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
+        # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
+        if (subdir := getattr(self, subdirectory, None)) is not None:
+            assert isinstance(subdir, Path), (
+                f"Expected a Path for {subdirectory}, got {type(subdir)}"
+            )
+            return subdir
+
+        dir = self.resolve_run_root_directory(run_id)
+        dir = dir / subdirectory
+        dir.mkdir(exist_ok=True)
+        return dir
+
+    def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
+        if (log_dir := logger.log_dir) is not None:
+            return log_dir
+
+        # Save to {logdir_basename}/{id}/log/{logger name}
+        log_dir = self.resolve_subdirectory(run_id, "log")
+        log_dir = log_dir / logger.resolve_logger_dirname()
+        # ^ NOTE: Logger must have a `name` attribute, as this is
+        # the discriminator for the logger registry
+        log_dir.mkdir(exist_ok=True)
+
+        return log_dir
+
 
 class TrainerConfig(C.Config):
     # region Active Run Configuration
-    id: str
+    id: Annotated[str, C.AllowMissing()] = C.MISSING
     """ID of the run."""
     name: list[str] = []
     """Run name in parts. Full name is constructed by joining the parts with spaces."""
@@ -393,39 +448,6 @@ class TrainerConfig(C.Config):
 
     directory: DirectoryConfig = DirectoryConfig()
     """Directory configuration options."""
-
-    _rng: ClassVar[np.random.Generator | None] = None
-
-    @classmethod
-    def generate_id(cls, *, length: int = 8) -> str:
-        """
-        Generate a random ID of specified length.
-
-        """
-        if (rng := cls._rng) is None:
-            rng = np.random.default_rng()
-
-        alphabet = list(string.ascii_lowercase + string.digits)
-
-        id = "".join(rng.choice(alphabet) for _ in range(length))
-        return id
-
-    @classmethod
-    def set_seed(cls, seed: int | None = None) -> None:
-        """
-        Set the seed for the random number generator.
-
-        Args:
-            seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
-
-        Returns:
-            None
-        """
-        if seed is None:
-            seed = int(time.time() * 1000)
-        log.critical(f"Seeding {cls.__name__} with seed {seed}")
-        cls._rng = np.random.default_rng(seed)
-
     # endregion
 
     primary_metric: MetricConfig | None = None
@@ -755,40 +777,40 @@
             None,
         )
 
-
-
-
-
-        yield self.lr_monitor
-        yield from (
-            logger_config
-            for logger_config in self.enabled_loggers()
-            if logger_config is not None
-            and isinstance(logger_config, CallbackConfigBase)
-        )
-        yield self.log_epoch
-        yield self.log_norms
-        yield self.hf_hub
-        yield self.shared_parameters
-        yield self.reduce_lr_on_plateau_sanity_checking
-        yield self.auto_set_debug_flag
-        yield self.auto_validate_metrics
-        yield from self.callbacks
+    # region Helper Methods
+    def id_(self, value: str):
+        """
+        Set the id for the trainer configuration in-place.
 
-
-
-
-
+        Parameters
+        ----------
+        value : str
+            The id value to set
 
-
-
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.id = value
+        return self
 
-        def
-
-
-
+    def with_id(self, value: str):
+        """
+        Create a copy of the current configuration with an updated id.
+
+        Parameters
+        ----------
+        value : str
+            The id value to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated id
+        """
+        return copy.deepcopy(self).id_(value)
 
-    # region Helper Methods
     def fast_dev_run_(self, value: int | bool = True, /):
         """
         Enables fast_dev_run mode for the trainer.
@@ -831,6 +853,349 @@ class TrainerConfig(C.Config):
         """
         return copy.deepcopy(self).project_root_(project_root)
 
+    def name_(self, *parts: str):
+        """
+        Set the name for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *parts : str
+            The parts of the name to set. Will be joined with spaces.
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.name = list(parts)
+        return self
+
+    def with_name(self, *parts: str):
+        """
+        Create a copy of the current configuration with an updated name.
+
+        Parameters
+        ----------
+        *parts : str
+            The parts of the name to set. Will be joined with spaces.
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated name
+        """
+        return copy.deepcopy(self).name_(*parts)
+
+    def project_(self, project: str | None):
+        """
+        Set the project name for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        project : str | None
+            The project name to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.project = project
+        return self
+
+    def with_project(self, project: str | None):
+        """
+        Create a copy of the current configuration with an updated project name.
+
+        Parameters
+        ----------
+        project : str | None
+            The project name to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated project name
+        """
+        return copy.deepcopy(self).project_(project)
+
+    def tags_(self, *tags: str):
+        """
+        Set the tags for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *tags : str
+            The tags to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.tags = list(tags)
+        return self
+
+    def with_tags(self, *tags: str):
+        """
+        Create a copy of the current configuration with updated tags.
+
+        Parameters
+        ----------
+        *tags : str
+            The tags to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated tags
+        """
+        return copy.deepcopy(self).tags_(*tags)
+
+    def add_tags_(self, *tags: str):
+        """
+        Add tags to the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *tags : str
+            The tags to add
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.tags.extend(tags)
+        return self
+
+    def with_added_tags(self, *tags: str):
+        """
+        Create a copy of the current configuration with additional tags.
+
+        Parameters
+        ----------
+        *tags : str
+            The tags to add
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the additional tags
+        """
+        return copy.deepcopy(self).add_tags_(*tags)
+
+    def notes_(self, *notes: str):
+        """
+        Set the notes for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *notes : str
+            The notes to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.notes = list(notes)
+        return self
+
+    def with_notes(self, *notes: str):
+        """
+        Create a copy of the current configuration with updated notes.
+
+        Parameters
+        ----------
+        *notes : str
+            The notes to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated notes
+        """
+        return copy.deepcopy(self).notes_(*notes)
+
+    def add_notes_(self, *notes: str):
+        """
+        Add notes to the trainer configuration in-place.
+
+        Parameters
+        ----------
+        *notes : str
+            The notes to add
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.notes.extend(notes)
+        return self
+
+    def with_added_notes(self, *notes: str):
+        """
+        Create a copy of the current configuration with additional notes.
+
+        Parameters
+        ----------
+        *notes : str
+            The notes to add
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the additional notes
+        """
+        return copy.deepcopy(self).add_notes_(*notes)
+
+    def meta_(self, meta: dict[str, Any] | None = None, /, **kwargs: Any):
+        """
+        Update the `meta` dictionary in-place with the provided key-value pairs.
+
+        This method allows updating the meta information associated with the trainer
+        configuration by either passing a dictionary or keyword arguments.
+
+        Parameters
+        ----------
+        meta : dict[str, Any] | None, optional
+            A dictionary containing meta information to be added, by default None
+        **kwargs : Any
+            Additional key-value pairs to be added to the meta dictionary
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        if meta is not None:
+            self.meta.update(meta)
+        self.meta.update(kwargs)
+        return self
+
+    def with_meta(self, meta: dict[str, Any] | None = None, /, **kwargs: Any):
+        """
+        Create a copy of the current configuration with updated meta information.
+
+        This method is similar to `meta_`, but it returns a new instance of the configuration
+        with the updated meta information instead of modifying the current instance.
+
+        Parameters
+        ----------
+        meta : dict[str, Any] | None, optional
+            A dictionary containing meta information to be added, by default None
+        **kwargs : Any
+            Additional key-value pairs to be added to the meta dictionary
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with updated meta information
+        """
+
+        return self.model_copy(deep=True).meta_(meta, **kwargs)
+
+    def debug_(self, value: bool = True):
+        """
+        Set the debug flag for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        value : bool, optional
+            The debug flag value to set, by default True
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.debug = value
+        return self
+
+    def with_debug(self, value: bool = True):
+        """
+        Create a copy of the current configuration with an updated debug flag.
+
+        Parameters
+        ----------
+        value : bool, optional
+            The debug flag value to set, by default True
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated debug flag
+        """
+        return copy.deepcopy(self).debug_(value)
+
+    def ckpt_path_(self, path: Literal["none"] | str | Path | None):
+        """
+        Set the checkpoint path for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        path : Literal["none"] | str | Path | None
+            The checkpoint path to set
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.ckpt_path = path
+        return self
+
+    def with_ckpt_path(self, path: Literal["none"] | str | Path | None):
+        """
+        Create a copy of the current configuration with an updated checkpoint path.
+
+        Parameters
+        ----------
+        path : Literal["none"] | str | Path | None
+            The checkpoint path to set
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated checkpoint path
+        """
+        return copy.deepcopy(self).ckpt_path_(path)
+
+    def barebones_(self, value: bool = True):
+        """
+        Set the barebones flag for the trainer configuration in-place.
+
+        Parameters
+        ----------
+        value : bool, optional
+            The barebones flag value to set, by default True
+
+        Returns
+        -------
+        self
+            Returns self for method chaining
+        """
+        self.barebones = value
+        return self
+
+    def with_barebones(self, value: bool = True):
+        """
+        Create a copy of the current configuration with an updated barebones flag.
+
+        Parameters
+        ----------
+        value : bool, optional
+            The barebones flag value to set, by default True
+
+        Returns
+        -------
+        TrainerConfig
+            A new instance of the configuration with the updated barebones flag
+        """
+        return copy.deepcopy(self).barebones_(value)
+
     def reset_run(
         self,
         *,
@@ -873,3 +1238,84 @@ class TrainerConfig(C.Config):
         return config
 
     # endregion
+
+    # region Random ID Generation
+    _rng: ClassVar[np.random.Generator | None] = None
+
+    @classmethod
+    def generate_id(cls, *, length: int = 8) -> str:
+        """
+        Generate a random ID of specified length.
+
+        """
+        if (rng := cls._rng) is None:
+            rng = np.random.default_rng()
+
+        alphabet = list(string.ascii_lowercase + string.digits)
+
+        id = "".join(rng.choice(alphabet) for _ in range(length))
+        return id
+
+    @classmethod
+    def set_seed(cls, seed: int | None = None) -> None:
+        """
+        Set the seed for the random number generator.
+
+        Args:
+            seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
+
+        Returns:
+            None
+        """
+        if seed is None:
+            seed = int(time.time() * 1000)
+        log.critical(f"Seeding {cls.__name__} with seed {seed}")
+        cls._rng = np.random.default_rng(seed)
+
+    # endregion
+
+    # region Internal Methods
+    def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
+        yield self.directory.setup_callback
+        yield self.early_stopping
+        yield self.checkpoint_saving
+        yield self.lr_monitor
+        yield from (
+            logger_config
+            for logger_config in self.enabled_loggers()
+            if logger_config is not None
+            and isinstance(logger_config, CallbackConfigBase)
+        )
+        yield self.log_epoch
+        yield self.log_norms
+        yield self.hf_hub
+        yield self.shared_parameters
+        yield self.reduce_lr_on_plateau_sanity_checking
+        yield self.auto_set_debug_flag
+        yield self.auto_validate_metrics
+        yield from self.callbacks
+
+    def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
+        # Disable all loggers if barebones mode is enabled
+        if self.barebones:
+            return
+
+        yield from self.enabled_loggers()
+        yield self.actsave_logger
+
+    def _nshtrainer_validate_before_run(self):
+        # shared_parameters is not supported under barebones mode
+        if self.barebones and self.shared_parameters:
+            raise ValueError("shared_parameters is not supported under barebones mode")
+
+    def _nshtrainer_set_id_if_missing(self):
+        """
+        Set the ID for the configuration object if it is missing.
+        """
+        if self.id is C.MISSING:
+            self.id = self.generate_id()
+            log.info(f"TrainerConfig's run ID is missing, setting to {self.id}.")
+        else:
+            log.debug(f"TrainerConfig's run ID is already set to {self.id}.")
+
+    # endregion
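Taken together, the `_config.py` changes fold `DirectoryConfig` into the trainer module, let the run `id` be omitted (it is filled in lazily via `generate_id()` before the run starts), and add a family of fluent setters: trailing-underscore methods mutate in place, `with_*` methods return a deep copy. A minimal usage sketch, not part of the diff, assuming the remaining `TrainerConfig` fields all have defaults (the project name, tags, and run id below are made up):

    from nshtrainer import TrainerConfig

    # id may be left missing; _nshtrainer_set_id_if_missing() fills it in with
    # generate_id() (eight random lowercase letters/digits) when the run starts.
    base = TrainerConfig()

    # In-place setters end with "_"; with_* variants deep-copy first.
    run_cfg = (
        base.with_name("baseline", "fp32")
        .with_project("example-project")
        .with_added_tags("ablation")
        .with_meta(seed=0)
    )

    # Run directories now default to <project_root>/nshtrainer_logs/<run id>/...
    # (the old standalone DirectoryConfig used "nshtrainer" as the basename);
    # note that resolve_subdirectory() creates the directory as a side effect.
    ckpt_dir = run_cfg.directory.resolve_subdirectory("abc12345", "checkpoint")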
nshtrainer/trainer/signal_connector.py
CHANGED
@@ -14,7 +14,6 @@ from pathlib import Path
 from types import FrameType
 from typing import Any
 
-import nshrunner as nr
 import torch.utils.data
 from lightning.fabric.plugins.environments.lsf import LSFEnvironment
 from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
@@ -34,6 +33,12 @@ _IS_WINDOWS = platform.system() == "Windows"
 
 
 def _resolve_requeue_signals():
+    try:
+        import nshrunner as nr
+    except ImportError:
+        log.debug("nshrunner not found. Skipping signal requeueing.")
+        return None
+
     if (session := nr.Session.from_current_session()) is None:
         return None
 
@@ -52,9 +57,9 @@ class _SignalConnector(_LightningSignalConnector):
 
         signals_set = set(signals)
         valid_signals: set[signal.Signals] = signal.valid_signals()
-        assert signals_set.issubset(
-            valid_signals
-        )
+        assert signals_set.issubset(valid_signals), (
+            f"Invalid signal(s) found: {signals_set - valid_signals}"
+        )
         return signals
 
     def _compose_and_register(
@@ -241,9 +246,9 @@ class _SignalConnector(_LightningSignalConnector):
                     "Writing requeue script to exit script directory."
                 )
                 exit_script_dir = Path(exit_script_dir)
-                assert (
-                    exit_script_dir
-                )
+                assert exit_script_dir.is_dir(), (
+                    f"Exit script directory {exit_script_dir} does not exist"
+                )
 
                 exit_script_path = exit_script_dir / f"requeue_{job_id}.sh"
                 log.info(f"Writing requeue script to {exit_script_path}")
nshtrainer/trainer/trainer.py
CHANGED
@@ -316,6 +316,7 @@ class Trainer(LightningTrainer):
                 f"Trainer hparams must either be an instance of {hparams_cls} or a mapping. "
                 f"Got {type(hparams)=} instead."
             )
+        hparams._nshtrainer_set_id_if_missing()
         hparams = hparams.model_deep_validate()
         hparams._nshtrainer_validate_before_run()
 
nshtrainer/util/_environment_info.py
CHANGED
@@ -356,12 +356,20 @@ class EnvironmentSnapshotConfig(C.Config):
 
     @classmethod
     def from_current_environment(cls):
-
-
-
-
-
-
+        try:
+            import nshrunner as nr
+
+            if (session := nr.Session.from_current_session()) is None:
+                log.warning("No active session found, skipping snapshot information")
+                return cls.empty()
+
+            draft = cls.draft()
+            draft.snapshot_dir = session.snapshot_dir
+            draft.modules = session.snapshot_modules
+            return draft.finalize()
+        except ImportError:
+            log.warning("nshrunner not found, skipping snapshot information")
+            return cls.empty()
 
 
 class EnvironmentPackageConfig(C.Config):
nshtrainer/util/code_upload.py
ADDED
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+
+log = logging.getLogger(__name__)
+
+
+def get_code_dir() -> Path | None:
+    try:
+        import nshrunner as nr
+
+        if (session := nr.Session.from_current_session()) is None:
+            log.debug("No active session found. Skipping code upload.")
+            return None
+
+        # New versions of nshrunner will have the code_dir attribute
+        # in the session object. We should use that. Otherwise, use snapshot_dir.
+        try:
+            code_dir = session.code_dir  # type: ignore
+        except AttributeError:
+            code_dir = session.snapshot_dir
+
+        if code_dir is None:
+            log.debug("No code directory found. Skipping code upload.")
+            return None
+
+        assert isinstance(code_dir, Path), (
+            f"Code directory should be a Path object. Got {type(code_dir)} instead."
+        )
+        if not code_dir.exists() or not code_dir.is_dir():
+            log.warning(
+                f"Code directory '{code_dir}' does not exist or is not a directory."
+            )
+            return None
+
+        return code_dir
+    except ImportError:
+        log.debug("nshrunner not found. Skipping code upload.")
+        return None
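The new helper centralizes the nshrunner lookup that the HF Hub and W&B callbacks above now share. A minimal sketch, not part of the diff, of how a caller consumes it:

    from nshtrainer.util.code_upload import get_code_dir

    # Returns the session's code_dir (newer nshrunner) or snapshot_dir (older
    # versions), or None when nshrunner is not installed, no session is active,
    # or the directory does not exist -- callers simply skip the upload then.
    if (code_dir := get_code_dir()) is not None:
        print(f"Uploading code from {code_dir.absolute()}")
    else:
        print("No nshrunner snapshot found; skipping code upload.")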
{nshtrainer-1.3.4.dist-info → nshtrainer-1.3.6.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: nshtrainer
-Version: 1.3.4
+Version: 1.3.6
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -15,7 +15,7 @@ Requires-Dist: GitPython ; extra == "extra"
 Requires-Dist: huggingface-hub ; extra == "extra"
 Requires-Dist: lightning
 Requires-Dist: nshconfig (>0.39)
-Requires-Dist: nshrunner
+Requires-Dist: nshrunner ; extra == "extra"
 Requires-Dist: nshutils ; extra == "extra"
 Requires-Dist: numpy
 Requires-Dist: packaging
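With this change `nshrunner` is no longer a hard dependency: it is only declared under the `extra` extra (e.g. `pip install "nshtrainer[extra]"`), which is why the modules above now import it inside `try`/`except ImportError` blocks and degrade gracefully when it is absent.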
{nshtrainer-1.3.4.dist-info → nshtrainer-1.3.6.dist-info}/RECORD
CHANGED
@@ -1,11 +1,10 @@
 nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
-nshtrainer/__init__.py,sha256=
+nshtrainer/__init__.py,sha256=RI_2B_IUWa10B6H5TAuWtE5FWX1X4ue-J4dTDaF2-lQ,1035
 nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
 nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2uLzQE,5597
 nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
-nshtrainer/_directory.py,sha256=RAG8e0y3VZwGIyy_D-GXgDMK5OvitQU6qEWxHTpWEeY,2490
 nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
-nshtrainer/_hf_hub.py,sha256=
+nshtrainer/_hf_hub.py,sha256=kfN0wDxK5JWKKGZnX_706i0KXGhaS19p581LDTPxlRE,13996
 nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
 nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
 nshtrainer/callbacks/base.py,sha256=K9aom1WVVRYxl-tHWgtmDUQZ1o63NgznvLsjauTKcCc,4225
@@ -30,13 +29,12 @@ nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0z
 nshtrainer/callbacks/rlp_sanity_checks.py,sha256=Df9Prq2QKXnaeMBIvMQBhDhJTDeru5UbiuXJOJR16Gk,10050
 nshtrainer/callbacks/shared_parameters.py,sha256=s94jJTAIbDGukYJu6l247QonVOCudGClU4t5kLt8XrY,3076
 nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU,4731
-nshtrainer/callbacks/wandb_upload_code.py,sha256=
+nshtrainer/callbacks/wandb_upload_code.py,sha256=4X-mpiX5ghj9vnEreK2i8Xyvimqt0K-PNWA2HtT-B6I,1940
 nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
 nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
-nshtrainer/configs/__init__.py,sha256
+nshtrainer/configs/__init__.py,sha256=-yJ5Uk9VkANqfk-QnX2aynL0jSf7cJQuQNzT1GAE1x8,15684
 nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
 nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
-nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
 nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
 nshtrainer/configs/callbacks/__init__.py,sha256=tP9urR73NIanyxpbi4EERsxOnGNiptbQpmsj-v53a38,4774
 nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JvjSZtEoA28FC4u-QT3skQzBDVbN9eq07rn4u2ydW-E,377
@@ -85,8 +83,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
 nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
 nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
 nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
-nshtrainer/configs/trainer/__init__.py,sha256=
-nshtrainer/configs/trainer/_config/__init__.py,sha256=
+nshtrainer/configs/trainer/__init__.py,sha256=DM2PlB4WRDZ_dqEeW91LbKRFa4sIF_pETU0T9GYJ5-g,8073
+nshtrainer/configs/trainer/_config/__init__.py,sha256=z5UpuXktBanLOYNkkbgbbHE06iQtcSuAKTpnx2TLmCo,3850
 nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
 nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
 nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=slW5z1FZw2qICXO9l9DnLIDB1Yl7KOcxPEZkyYIHrp4,276
@@ -135,7 +133,7 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
 nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
 nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
 nshtrainer/trainer/__init__.py,sha256=jRaHdaFK8wxNrN1bleT9cf29iZahL_-XkWo5TWz2CmA,550
-nshtrainer/trainer/_config.py,sha256=
+nshtrainer/trainer/_config.py,sha256=x3YjP_0IykqSRh8YIilCxq3nPt_fZXoVcxfR13ulmV0,45578
 nshtrainer/trainer/_distributed_prediction_result.py,sha256=bQw8Z6PT694UUf-zQPkech6CxyUSy8bAIexfSfPej0U,2507
 nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
 nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
@@ -146,11 +144,12 @@ nshtrainer/trainer/plugin/environment.py,sha256=SSXRWHjyFUA6oFx3duD_ZwhM59pWUjR1
 nshtrainer/trainer/plugin/io.py,sha256=OmFSKLloMypletjaUr_Ptg6LS0ljqTVIp2o4Hm3eZoE,1926
 nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLcOfPXnvH29s,663
 nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
-nshtrainer/trainer/signal_connector.py,sha256=
+nshtrainer/trainer/signal_connector.py,sha256=ZgbSkbthoe8MYN6rBoFf-7UDpQtc9fs9pG_FNvTYSfs,10962
 nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
-nshtrainer/trainer/trainer.py,sha256=
-nshtrainer/util/_environment_info.py,sha256=
+nshtrainer/trainer/trainer.py,sha256=iQWu0KfwLY-1q9EEsg0xPlyUN1fsJ9iXfSQbPmiMlac,24177
+nshtrainer/util/_environment_info.py,sha256=j-wyEHKirsu3rIXTtqC2kLmIIkRe6obWjxPVWaqg2ow,24887
 nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
+nshtrainer/util/code_upload.py,sha256=CpbZEBbA8EcBElUVoCPbP5zdwtNzJhS20RLaOB-q-2k,1257
 nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
 nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
 nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
@@ -160,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
 nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
 nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
 nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
-nshtrainer-1.3.
-nshtrainer-1.3.
-nshtrainer-1.3.
+nshtrainer-1.3.6.dist-info/METADATA,sha256=KidLM7J5P7mALTfYQsveSDjAOJZ-Gcq4_e1-Xgrms68,979
+nshtrainer-1.3.6.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+nshtrainer-1.3.6.dist-info/RECORD,,
nshtrainer/_directory.py
DELETED
@@ -1,72 +0,0 @@
-from __future__ import annotations
-
-import logging
-from pathlib import Path
-
-import nshconfig as C
-
-from .callbacks.directory_setup import DirectorySetupCallbackConfig
-from .loggers import LoggerConfig
-
-log = logging.getLogger(__name__)
-
-
-class DirectoryConfig(C.Config):
-    project_root: Path | None = None
-    """
-    Root directory for this project.
-
-    This isn't specific to the run; it is the parent directory of all runs.
-    """
-
-    logdir_basename: str = "nshtrainer"
-    """Base name for the log directory."""
-
-    setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
-    """Configuration for the directory setup PyTorch Lightning callback."""
-
-    def resolve_run_root_directory(self, run_id: str) -> Path:
-        if (project_root_dir := self.project_root) is None:
-            project_root_dir = Path.cwd()
-
-        # The default base dir is $CWD/{logdir_basename}/{id}/
-        base_dir = project_root_dir / self.logdir_basename
-        base_dir.mkdir(exist_ok=True)
-
-        # Add a .gitignore file to the {logdir_basename} directory
-        # which will ignore all files except for the .gitignore file itself
-        gitignore_path = base_dir / ".gitignore"
-        if not gitignore_path.exists():
-            gitignore_path.touch()
-            gitignore_path.write_text("*\n")
-
-        base_dir = base_dir / run_id
-        base_dir.mkdir(exist_ok=True)
-
-        return base_dir
-
-    def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
-        # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
-        if (subdir := getattr(self, subdirectory, None)) is not None:
-            assert isinstance(subdir, Path), (
-                f"Expected a Path for {subdirectory}, got {type(subdir)}"
-            )
-            return subdir
-
-        dir = self.resolve_run_root_directory(run_id)
-        dir = dir / subdirectory
-        dir.mkdir(exist_ok=True)
-        return dir
-
-    def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
-        if (log_dir := logger.log_dir) is not None:
-            return log_dir
-
-        # Save to {logdir_basename}/{id}/log/{logger name}
-        log_dir = self.resolve_subdirectory(run_id, "log")
-        log_dir = log_dir / logger.resolve_logger_dirname()
-        # ^ NOTE: Logger must have a `name` attribute, as this is
-        # the discriminator for the logger registry
-        log_dir.mkdir(exist_ok=True)
-
-        return log_dir
nshtrainer/configs/_directory/__init__.py
DELETED
@@ -1,15 +0,0 @@
-from __future__ import annotations
-
-__codegen__ = True
-
-from nshtrainer._directory import DirectoryConfig as DirectoryConfig
-from nshtrainer._directory import (
-    DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
-)
-from nshtrainer._directory import LoggerConfig as LoggerConfig
-
-__all__ = [
-    "DirectoryConfig",
-    "DirectorySetupCallbackConfig",
-    "LoggerConfig",
-]
{nshtrainer-1.3.4.dist-info → nshtrainer-1.3.6.dist-info}/WHEEL
File without changes