nshtrainer 1.3.4__py3-none-any.whl → 1.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/__init__.py CHANGED
@@ -19,3 +19,17 @@ try:
19
19
  from . import configs as configs
20
20
  except BaseException:
21
21
  pass
22
+
23
+ try:
24
+ from importlib.metadata import PackageNotFoundError, version
25
+ except ImportError:
26
+ # For Python <3.8
27
+ from importlib_metadata import ( # pyright: ignore[reportMissingImports]
28
+ PackageNotFoundError,
29
+ version,
30
+ )
31
+
32
+ try:
33
+ __version__ = version(__name__)
34
+ except PackageNotFoundError:
35
+ __version__ = "unknown"
nshtrainer/_hf_hub.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  import contextlib
4
4
  import logging
5
- import os
6
5
  import re
7
6
  from dataclasses import dataclass
8
7
  from functools import cached_property
@@ -10,7 +9,6 @@ from pathlib import Path
10
9
  from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
11
10
 
12
11
  import nshconfig as C
13
- from nshrunner._env import SNAPSHOT_DIR
14
12
  from typing_extensions import assert_never, override
15
13
 
16
14
  from ._callback import NTCallbackBase
@@ -19,6 +17,7 @@ from .callbacks.base import (
19
17
  CallbackMetadataConfig,
20
18
  callback_registry,
21
19
  )
20
+ from .util.code_upload import get_code_dir
22
21
 
23
22
  if TYPE_CHECKING:
24
23
  from huggingface_hub import HfApi # noqa: F401
@@ -319,20 +318,13 @@ class HFHubCallback(NTCallbackBase):
319
318
  def _save_code(self):
320
319
  # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
321
320
  # then upload all contents within the snapshot directory to the repository.
322
- if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
321
+ if (snapshot_dir := get_code_dir()) is None:
323
322
  log.debug("No snapshot directory found. Skipping upload.")
324
323
  return
325
324
 
326
325
  with self._with_error_handling("save code"):
327
- snapshot_dir = Path(snapshot_dir)
328
- if not snapshot_dir.exists() or not snapshot_dir.is_dir():
329
- log.warning(
330
- f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
331
- )
332
- return
333
-
334
326
  self.api.upload_folder(
335
- folder_path=str(snapshot_dir),
327
+ folder_path=str(snapshot_dir.absolute()),
336
328
  repo_id=self.repo_id,
337
329
  repo_type="model",
338
330
  path_in_repo="code", # Prefix with "code" folder
nshtrainer/callbacks/wandb_upload_code.py CHANGED
@@ -1,16 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- import os
5
- from pathlib import Path
6
4
  from typing import Literal, cast
7
5
 
8
6
  from lightning.pytorch import LightningModule, Trainer
9
7
  from lightning.pytorch.callbacks.callback import Callback
10
8
  from lightning.pytorch.loggers import WandbLogger
11
- from nshrunner._env import SNAPSHOT_DIR
12
9
  from typing_extensions import final, override
13
10
 
11
+ from ..util.code_upload import get_code_dir
14
12
  from .base import CallbackConfigBase, callback_registry
15
13
 
16
14
  log = logging.getLogger(__name__)
@@ -62,22 +60,12 @@ class WandbUploadCodeCallback(Callback):
62
60
  log.warning("Wandb logger not found. Skipping code upload.")
63
61
  return
64
62
 
65
- from wandb.wandb_run import Run
66
-
67
- run = cast(Run, logger.experiment)
68
-
69
- # If a snapshot has been taken (which can be detected using the SNAPSHOT_DIR env),
70
- # then upload all contents within the snapshot directory to the repository.
71
- if not (snapshot_dir := os.environ.get(SNAPSHOT_DIR)):
72
- log.debug("No snapshot directory found. Skipping upload.")
63
+ if (snapshot_dir := get_code_dir()) is None:
64
+ log.info("No nshrunner snapshot found. Skipping code upload.")
73
65
  return
74
66
 
75
- snapshot_dir = Path(snapshot_dir)
76
- if not snapshot_dir.exists() or not snapshot_dir.is_dir():
77
- log.warning(
78
- f"Snapshot directory '{snapshot_dir}' does not exist or is not a directory."
79
- )
80
- return
67
+ from wandb.wandb_run import Run
81
68
 
69
+ run = cast(Run, logger.experiment)
82
70
  log.info(f"Uploading code from snapshot directory '{snapshot_dir}'")
83
71
  run.log_code(str(snapshot_dir.absolute()))
nshtrainer/configs/__init__.py CHANGED
@@ -5,7 +5,6 @@ __codegen__ = True
5
5
  from nshtrainer import MetricConfig as MetricConfig
6
6
  from nshtrainer import TrainerConfig as TrainerConfig
7
7
  from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
8
- from nshtrainer._directory import DirectoryConfig as DirectoryConfig
9
8
  from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
10
9
  from nshtrainer._hf_hub import (
11
10
  HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
@@ -126,9 +125,9 @@ from nshtrainer.trainer._config import (
126
125
  CheckpointCallbackConfig as CheckpointCallbackConfig,
127
126
  )
128
127
  from nshtrainer.trainer._config import CheckpointSavingConfig as CheckpointSavingConfig
128
+ from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
129
129
  from nshtrainer.trainer._config import EnvironmentConfig as EnvironmentConfig
130
130
  from nshtrainer.trainer._config import GradientClippingConfig as GradientClippingConfig
131
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
132
131
  from nshtrainer.trainer._config import StrategyConfig as StrategyConfig
133
132
  from nshtrainer.trainer.accelerator import CPUAcceleratorConfig as CPUAcceleratorConfig
134
133
  from nshtrainer.trainer.accelerator import (
@@ -227,7 +226,6 @@ from nshtrainer.util.config import EpochsConfig as EpochsConfig
227
226
  from nshtrainer.util.config import StepsConfig as StepsConfig
228
227
 
229
228
  from . import _checkpoint as _checkpoint
230
- from . import _directory as _directory
231
229
  from . import _hf_hub as _hf_hub
232
230
  from . import callbacks as callbacks
233
231
  from . import loggers as loggers
@@ -338,7 +336,6 @@ __all__ = [
338
336
  "RpropConfig",
339
337
  "SGDConfig",
340
338
  "SLURMEnvironmentPlugin",
341
- "SanityCheckingConfig",
342
339
  "SharedParametersCallbackConfig",
343
340
  "SiLUNonlinearityConfig",
344
341
  "SigmoidNonlinearityConfig",
@@ -367,7 +364,6 @@ __all__ = [
367
364
  "XLAEnvironmentPlugin",
368
365
  "XLAPluginConfig",
369
366
  "_checkpoint",
370
- "_directory",
371
367
  "_hf_hub",
372
368
  "accelerator_registry",
373
369
  "callback_registry",
nshtrainer/configs/trainer/__init__.py CHANGED
@@ -22,6 +22,9 @@ from nshtrainer.trainer._config import (
22
22
  DebugFlagCallbackConfig as DebugFlagCallbackConfig,
23
23
  )
24
24
  from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
25
+ from nshtrainer.trainer._config import (
26
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
27
+ )
25
28
  from nshtrainer.trainer._config import (
26
29
  EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
27
30
  )
@@ -51,7 +54,6 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
51
54
  from nshtrainer.trainer._config import (
52
55
  RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
53
56
  )
54
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
55
57
  from nshtrainer.trainer._config import (
56
58
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
57
59
  )
@@ -152,6 +154,7 @@ __all__ = [
152
154
  "DebugFlagCallbackConfig",
153
155
  "DeepSpeedPluginConfig",
154
156
  "DirectoryConfig",
157
+ "DirectorySetupCallbackConfig",
155
158
  "DistributedPredictionWriterConfig",
156
159
  "DoublePrecisionPluginConfig",
157
160
  "EarlyStoppingCallbackConfig",
@@ -180,7 +183,6 @@ __all__ = [
180
183
  "ProfilerConfig",
181
184
  "RLPSanityChecksCallbackConfig",
182
185
  "SLURMEnvironmentPlugin",
183
- "SanityCheckingConfig",
184
186
  "SharedParametersCallbackConfig",
185
187
  "StrategyConfig",
186
188
  "StrategyConfigBase",
nshtrainer/configs/trainer/_config/__init__.py CHANGED
@@ -18,6 +18,9 @@ from nshtrainer.trainer._config import (
18
18
  DebugFlagCallbackConfig as DebugFlagCallbackConfig,
19
19
  )
20
20
  from nshtrainer.trainer._config import DirectoryConfig as DirectoryConfig
21
+ from nshtrainer.trainer._config import (
22
+ DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
23
+ )
21
24
  from nshtrainer.trainer._config import (
22
25
  EarlyStoppingCallbackConfig as EarlyStoppingCallbackConfig,
23
26
  )
@@ -48,7 +51,6 @@ from nshtrainer.trainer._config import ProfilerConfig as ProfilerConfig
48
51
  from nshtrainer.trainer._config import (
49
52
  RLPSanityChecksCallbackConfig as RLPSanityChecksCallbackConfig,
50
53
  )
51
- from nshtrainer.trainer._config import SanityCheckingConfig as SanityCheckingConfig
52
54
  from nshtrainer.trainer._config import (
53
55
  SharedParametersCallbackConfig as SharedParametersCallbackConfig,
54
56
  )
@@ -70,6 +72,7 @@ __all__ = [
70
72
  "CheckpointSavingConfig",
71
73
  "DebugFlagCallbackConfig",
72
74
  "DirectoryConfig",
75
+ "DirectorySetupCallbackConfig",
73
76
  "EarlyStoppingCallbackConfig",
74
77
  "EnvironmentConfig",
75
78
  "GradientClippingConfig",
@@ -86,7 +89,6 @@ __all__ = [
86
89
  "PluginConfig",
87
90
  "ProfilerConfig",
88
91
  "RLPSanityChecksCallbackConfig",
89
- "SanityCheckingConfig",
90
92
  "SharedParametersCallbackConfig",
91
93
  "StrategyConfig",
92
94
  "TensorboardLoggerConfig",
nshtrainer/trainer/_config.py CHANGED
@@ -26,7 +26,6 @@ from lightning.pytorch.profilers import Profiler
26
26
  from lightning.pytorch.strategies.strategy import Strategy
27
27
  from typing_extensions import TypeAliasType, TypedDict, override
28
28
 
29
- from .._directory import DirectoryConfig
30
29
  from .._hf_hub import HuggingFaceHubConfig
31
30
  from ..callbacks import (
32
31
  BestCheckpointCallbackConfig,
@@ -38,6 +37,7 @@ from ..callbacks import (
38
37
  )
39
38
  from ..callbacks.base import CallbackConfigBase
40
39
  from ..callbacks.debug_flag import DebugFlagCallbackConfig
40
+ from ..callbacks.directory_setup import DirectorySetupCallbackConfig
41
41
  from ..callbacks.log_epoch import LogEpochCallbackConfig
42
42
  from ..callbacks.lr_monitor import LearningRateMonitorConfig
43
43
  from ..callbacks.metric_validation import MetricValidationCallbackConfig
@@ -352,19 +352,74 @@ class LightningTrainerKwargs(TypedDict, total=False):
352
352
  """
353
353
 
354
354
 
355
- class SanityCheckingConfig(C.Config):
356
- reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
355
+ DEFAULT_LOGDIR_BASENAME = "nshtrainer_logs"
356
+ """Default base name for the log directory."""
357
+
358
+
359
+ class DirectoryConfig(C.Config):
360
+ project_root: Path | None = None
357
361
  """
358
- If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
359
- - If the `interval` is step, it makes sure that validation is called every `frequency` steps.
360
- - If the `interval` is epoch, it makes sure that validation is called every `frequency` epochs.
361
- Valid values are: "disable", "warn", "error".
362
+ Root directory for this project.
363
+
364
+ This isn't specific to the current run; it is the parent directory of all runs.
362
365
  """
363
366
 
367
+ logdir_basename: str = DEFAULT_LOGDIR_BASENAME
368
+ """Base name for the log directory."""
369
+
370
+ setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
371
+ """Configuration for the directory setup PyTorch Lightning callback."""
372
+
373
+ def resolve_run_root_directory(self, run_id: str) -> Path:
374
+ if (project_root_dir := self.project_root) is None:
375
+ project_root_dir = Path.cwd()
376
+
377
+ # The default base dir is $CWD/{logdir_basename}/{id}/
378
+ base_dir = project_root_dir / self.logdir_basename
379
+ base_dir.mkdir(exist_ok=True)
380
+
381
+ # Add a .gitignore file to the {logdir_basename} directory
382
+ # which will ignore all files except for the .gitignore file itself
383
+ gitignore_path = base_dir / ".gitignore"
384
+ if not gitignore_path.exists():
385
+ gitignore_path.touch()
386
+ gitignore_path.write_text("*\n")
387
+
388
+ base_dir = base_dir / run_id
389
+ base_dir.mkdir(exist_ok=True)
390
+
391
+ return base_dir
392
+
393
+ def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
394
+ # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
395
+ if (subdir := getattr(self, subdirectory, None)) is not None:
396
+ assert isinstance(subdir, Path), (
397
+ f"Expected a Path for {subdirectory}, got {type(subdir)}"
398
+ )
399
+ return subdir
400
+
401
+ dir = self.resolve_run_root_directory(run_id)
402
+ dir = dir / subdirectory
403
+ dir.mkdir(exist_ok=True)
404
+ return dir
405
+
406
+ def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
407
+ if (log_dir := logger.log_dir) is not None:
408
+ return log_dir
409
+
410
+ # Save to {logdir_basename}/{id}/log/{logger name}
411
+ log_dir = self.resolve_subdirectory(run_id, "log")
412
+ log_dir = log_dir / logger.resolve_logger_dirname()
413
+ # ^ NOTE: Logger must have a `name` attribute, as this is
414
+ # the discriminator for the logger registry
415
+ log_dir.mkdir(exist_ok=True)
416
+
417
+ return log_dir
418
+
364
419
 
365
420
  class TrainerConfig(C.Config):
366
421
  # region Active Run Configuration
367
- id: str = C.Field(default_factory=lambda: TrainerConfig.generate_id())
422
+ id: Annotated[str, C.AllowMissing()] = C.MISSING
368
423
  """ID of the run."""
369
424
  name: list[str] = []
370
425
  """Run name in parts. Full name is constructed by joining the parts with spaces."""
@@ -393,39 +448,6 @@ class TrainerConfig(C.Config):
393
448
 
394
449
  directory: DirectoryConfig = DirectoryConfig()
395
450
  """Directory configuration options."""
396
-
397
- _rng: ClassVar[np.random.Generator | None] = None
398
-
399
- @classmethod
400
- def generate_id(cls, *, length: int = 8) -> str:
401
- """
402
- Generate a random ID of specified length.
403
-
404
- """
405
- if (rng := cls._rng) is None:
406
- rng = np.random.default_rng()
407
-
408
- alphabet = list(string.ascii_lowercase + string.digits)
409
-
410
- id = "".join(rng.choice(alphabet) for _ in range(length))
411
- return id
412
-
413
- @classmethod
414
- def set_seed(cls, seed: int | None = None) -> None:
415
- """
416
- Set the seed for the random number generator.
417
-
418
- Args:
419
- seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
420
-
421
- Returns:
422
- None
423
- """
424
- if seed is None:
425
- seed = int(time.time() * 1000)
426
- log.critical(f"Seeding {cls.__name__} with seed {seed}")
427
- cls._rng = np.random.default_rng(seed)
428
-
429
451
  # endregion
430
452
 
431
453
  primary_metric: MetricConfig | None = None
@@ -755,40 +777,40 @@ class TrainerConfig(C.Config):
755
777
  None,
756
778
  )
757
779
 
758
- def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
759
- yield self.directory.setup_callback
760
- yield self.early_stopping
761
- yield self.checkpoint_saving
762
- yield self.lr_monitor
763
- yield from (
764
- logger_config
765
- for logger_config in self.enabled_loggers()
766
- if logger_config is not None
767
- and isinstance(logger_config, CallbackConfigBase)
768
- )
769
- yield self.log_epoch
770
- yield self.log_norms
771
- yield self.hf_hub
772
- yield self.shared_parameters
773
- yield self.reduce_lr_on_plateau_sanity_checking
774
- yield self.auto_set_debug_flag
775
- yield self.auto_validate_metrics
776
- yield from self.callbacks
780
+ # region Helper Methods
781
+ def id_(self, value: str):
782
+ """
783
+ Set the id for the trainer configuration in-place.
777
784
 
778
- def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
779
- # Disable all loggers if barebones mode is enabled
780
- if self.barebones:
781
- return
785
+ Parameters
786
+ ----------
787
+ value : str
788
+ The id value to set
782
789
 
783
- yield from self.enabled_loggers()
784
- yield self.actsave_logger
790
+ Returns
791
+ -------
792
+ self
793
+ Returns self for method chaining
794
+ """
795
+ self.id = value
796
+ return self
785
797
 
786
- def _nshtrainer_validate_before_run(self):
787
- # shared_parameters is not supported under barebones mode
788
- if self.barebones and self.shared_parameters:
789
- raise ValueError("shared_parameters is not supported under barebones mode")
798
+ def with_id(self, value: str):
799
+ """
800
+ Create a copy of the current configuration with an updated id.
801
+
802
+ Parameters
803
+ ----------
804
+ value : str
805
+ The id value to set
806
+
807
+ Returns
808
+ -------
809
+ TrainerConfig
810
+ A new instance of the configuration with the updated id
811
+ """
812
+ return copy.deepcopy(self).id_(value)
790
813
 
791
- # region Helper Methods
792
814
  def fast_dev_run_(self, value: int | bool = True, /):
793
815
  """
794
816
  Enables fast_dev_run mode for the trainer.
@@ -831,6 +853,349 @@ class TrainerConfig(C.Config):
831
853
  """
832
854
  return copy.deepcopy(self).project_root_(project_root)
833
855
 
856
+ def name_(self, *parts: str):
857
+ """
858
+ Set the name for the trainer configuration in-place.
859
+
860
+ Parameters
861
+ ----------
862
+ *parts : str
863
+ The parts of the name to set. Will be joined with spaces.
864
+
865
+ Returns
866
+ -------
867
+ self
868
+ Returns self for method chaining
869
+ """
870
+ self.name = list(parts)
871
+ return self
872
+
873
+ def with_name(self, *parts: str):
874
+ """
875
+ Create a copy of the current configuration with an updated name.
876
+
877
+ Parameters
878
+ ----------
879
+ *parts : str
880
+ The parts of the name to set. Will be joined with spaces.
881
+
882
+ Returns
883
+ -------
884
+ TrainerConfig
885
+ A new instance of the configuration with the updated name
886
+ """
887
+ return copy.deepcopy(self).name_(*parts)
888
+
889
+ def project_(self, project: str | None):
890
+ """
891
+ Set the project name for the trainer configuration in-place.
892
+
893
+ Parameters
894
+ ----------
895
+ project : str | None
896
+ The project name to set
897
+
898
+ Returns
899
+ -------
900
+ self
901
+ Returns self for method chaining
902
+ """
903
+ self.project = project
904
+ return self
905
+
906
+ def with_project(self, project: str | None):
907
+ """
908
+ Create a copy of the current configuration with an updated project name.
909
+
910
+ Parameters
911
+ ----------
912
+ project : str | None
913
+ The project name to set
914
+
915
+ Returns
916
+ -------
917
+ TrainerConfig
918
+ A new instance of the configuration with the updated project name
919
+ """
920
+ return copy.deepcopy(self).project_(project)
921
+
922
+ def tags_(self, *tags: str):
923
+ """
924
+ Set the tags for the trainer configuration in-place.
925
+
926
+ Parameters
927
+ ----------
928
+ *tags : str
929
+ The tags to set
930
+
931
+ Returns
932
+ -------
933
+ self
934
+ Returns self for method chaining
935
+ """
936
+ self.tags = list(tags)
937
+ return self
938
+
939
+ def with_tags(self, *tags: str):
940
+ """
941
+ Create a copy of the current configuration with updated tags.
942
+
943
+ Parameters
944
+ ----------
945
+ *tags : str
946
+ The tags to set
947
+
948
+ Returns
949
+ -------
950
+ TrainerConfig
951
+ A new instance of the configuration with the updated tags
952
+ """
953
+ return copy.deepcopy(self).tags_(*tags)
954
+
955
+ def add_tags_(self, *tags: str):
956
+ """
957
+ Add tags to the trainer configuration in-place.
958
+
959
+ Parameters
960
+ ----------
961
+ *tags : str
962
+ The tags to add
963
+
964
+ Returns
965
+ -------
966
+ self
967
+ Returns self for method chaining
968
+ """
969
+ self.tags.extend(tags)
970
+ return self
971
+
972
+ def with_added_tags(self, *tags: str):
973
+ """
974
+ Create a copy of the current configuration with additional tags.
975
+
976
+ Parameters
977
+ ----------
978
+ *tags : str
979
+ The tags to add
980
+
981
+ Returns
982
+ -------
983
+ TrainerConfig
984
+ A new instance of the configuration with the additional tags
985
+ """
986
+ return copy.deepcopy(self).add_tags_(*tags)
987
+
988
+ def notes_(self, *notes: str):
989
+ """
990
+ Set the notes for the trainer configuration in-place.
991
+
992
+ Parameters
993
+ ----------
994
+ *notes : str
995
+ The notes to set
996
+
997
+ Returns
998
+ -------
999
+ self
1000
+ Returns self for method chaining
1001
+ """
1002
+ self.notes = list(notes)
1003
+ return self
1004
+
1005
+ def with_notes(self, *notes: str):
1006
+ """
1007
+ Create a copy of the current configuration with updated notes.
1008
+
1009
+ Parameters
1010
+ ----------
1011
+ *notes : str
1012
+ The notes to set
1013
+
1014
+ Returns
1015
+ -------
1016
+ TrainerConfig
1017
+ A new instance of the configuration with the updated notes
1018
+ """
1019
+ return copy.deepcopy(self).notes_(*notes)
1020
+
1021
+ def add_notes_(self, *notes: str):
1022
+ """
1023
+ Add notes to the trainer configuration in-place.
1024
+
1025
+ Parameters
1026
+ ----------
1027
+ *notes : str
1028
+ The notes to add
1029
+
1030
+ Returns
1031
+ -------
1032
+ self
1033
+ Returns self for method chaining
1034
+ """
1035
+ self.notes.extend(notes)
1036
+ return self
1037
+
1038
+ def with_added_notes(self, *notes: str):
1039
+ """
1040
+ Create a copy of the current configuration with additional notes.
1041
+
1042
+ Parameters
1043
+ ----------
1044
+ *notes : str
1045
+ The notes to add
1046
+
1047
+ Returns
1048
+ -------
1049
+ TrainerConfig
1050
+ A new instance of the configuration with the additional notes
1051
+ """
1052
+ return copy.deepcopy(self).add_notes_(*notes)
1053
+
1054
+ def meta_(self, meta: dict[str, Any] | None = None, /, **kwargs: Any):
1055
+ """
1056
+ Update the `meta` dictionary in-place with the provided key-value pairs.
1057
+
1058
+ This method allows updating the meta information associated with the trainer
1059
+ configuration by either passing a dictionary or keyword arguments.
1060
+
1061
+ Parameters
1062
+ ----------
1063
+ meta : dict[str, Any] | None, optional
1064
+ A dictionary containing meta information to be added, by default None
1065
+ **kwargs : Any
1066
+ Additional key-value pairs to be added to the meta dictionary
1067
+
1068
+ Returns
1069
+ -------
1070
+ self
1071
+ Returns self for method chaining
1072
+ """
1073
+ if meta is not None:
1074
+ self.meta.update(meta)
1075
+ self.meta.update(kwargs)
1076
+ return self
1077
+
1078
+ def with_meta(self, meta: dict[str, Any] | None = None, /, **kwargs: Any):
1079
+ """
1080
+ Create a copy of the current configuration with updated meta information.
1081
+
1082
+ This method is similar to `meta_`, but it returns a new instance of the configuration
1083
+ with the updated meta information instead of modifying the current instance.
1084
+
1085
+ Parameters
1086
+ ----------
1087
+ meta : dict[str, Any] | None, optional
1088
+ A dictionary containing meta information to be added, by default None
1089
+ **kwargs : Any
1090
+ Additional key-value pairs to be added to the meta dictionary
1091
+
1092
+ Returns
1093
+ -------
1094
+ TrainerConfig
1095
+ A new instance of the configuration with updated meta information
1096
+ """
1097
+
1098
+ return self.model_copy(deep=True).meta_(meta, **kwargs)
1099
+
1100
+ def debug_(self, value: bool = True):
1101
+ """
1102
+ Set the debug flag for the trainer configuration in-place.
1103
+
1104
+ Parameters
1105
+ ----------
1106
+ value : bool, optional
1107
+ The debug flag value to set, by default True
1108
+
1109
+ Returns
1110
+ -------
1111
+ self
1112
+ Returns self for method chaining
1113
+ """
1114
+ self.debug = value
1115
+ return self
1116
+
1117
+ def with_debug(self, value: bool = True):
1118
+ """
1119
+ Create a copy of the current configuration with an updated debug flag.
1120
+
1121
+ Parameters
1122
+ ----------
1123
+ value : bool, optional
1124
+ The debug flag value to set, by default True
1125
+
1126
+ Returns
1127
+ -------
1128
+ TrainerConfig
1129
+ A new instance of the configuration with the updated debug flag
1130
+ """
1131
+ return copy.deepcopy(self).debug_(value)
1132
+
1133
+ def ckpt_path_(self, path: Literal["none"] | str | Path | None):
1134
+ """
1135
+ Set the checkpoint path for the trainer configuration in-place.
1136
+
1137
+ Parameters
1138
+ ----------
1139
+ path : Literal["none"] | str | Path | None
1140
+ The checkpoint path to set
1141
+
1142
+ Returns
1143
+ -------
1144
+ self
1145
+ Returns self for method chaining
1146
+ """
1147
+ self.ckpt_path = path
1148
+ return self
1149
+
1150
+ def with_ckpt_path(self, path: Literal["none"] | str | Path | None):
1151
+ """
1152
+ Create a copy of the current configuration with an updated checkpoint path.
1153
+
1154
+ Parameters
1155
+ ----------
1156
+ path : Literal["none"] | str | Path | None
1157
+ The checkpoint path to set
1158
+
1159
+ Returns
1160
+ -------
1161
+ TrainerConfig
1162
+ A new instance of the configuration with the updated checkpoint path
1163
+ """
1164
+ return copy.deepcopy(self).ckpt_path_(path)
1165
+
1166
+ def barebones_(self, value: bool = True):
1167
+ """
1168
+ Set the barebones flag for the trainer configuration in-place.
1169
+
1170
+ Parameters
1171
+ ----------
1172
+ value : bool, optional
1173
+ The barebones flag value to set, by default True
1174
+
1175
+ Returns
1176
+ -------
1177
+ self
1178
+ Returns self for method chaining
1179
+ """
1180
+ self.barebones = value
1181
+ return self
1182
+
1183
+ def with_barebones(self, value: bool = True):
1184
+ """
1185
+ Create a copy of the current configuration with an updated barebones flag.
1186
+
1187
+ Parameters
1188
+ ----------
1189
+ value : bool, optional
1190
+ The barebones flag value to set, by default True
1191
+
1192
+ Returns
1193
+ -------
1194
+ TrainerConfig
1195
+ A new instance of the configuration with the updated barebones flag
1196
+ """
1197
+ return copy.deepcopy(self).barebones_(value)
1198
+
834
1199
  def reset_run(
835
1200
  self,
836
1201
  *,
@@ -873,3 +1238,84 @@ class TrainerConfig(C.Config):
873
1238
  return config
874
1239
 
875
1240
  # endregion
1241
+
1242
+ # region Random ID Generation
1243
+ _rng: ClassVar[np.random.Generator | None] = None
1244
+
1245
+ @classmethod
1246
+ def generate_id(cls, *, length: int = 8) -> str:
1247
+ """
1248
+ Generate a random ID of specified length.
1249
+
1250
+ """
1251
+ if (rng := cls._rng) is None:
1252
+ rng = np.random.default_rng()
1253
+
1254
+ alphabet = list(string.ascii_lowercase + string.digits)
1255
+
1256
+ id = "".join(rng.choice(alphabet) for _ in range(length))
1257
+ return id
1258
+
1259
+ @classmethod
1260
+ def set_seed(cls, seed: int | None = None) -> None:
1261
+ """
1262
+ Set the seed for the random number generator.
1263
+
1264
+ Args:
1265
+ seed (int | None, optional): The seed value to set. If None, a seed based on the current time will be used. Defaults to None.
1266
+
1267
+ Returns:
1268
+ None
1269
+ """
1270
+ if seed is None:
1271
+ seed = int(time.time() * 1000)
1272
+ log.critical(f"Seeding {cls.__name__} with seed {seed}")
1273
+ cls._rng = np.random.default_rng(seed)
1274
+
1275
+ # endregion
1276
+
1277
+ # region Internal Methods
1278
+ def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
1279
+ yield self.directory.setup_callback
1280
+ yield self.early_stopping
1281
+ yield self.checkpoint_saving
1282
+ yield self.lr_monitor
1283
+ yield from (
1284
+ logger_config
1285
+ for logger_config in self.enabled_loggers()
1286
+ if logger_config is not None
1287
+ and isinstance(logger_config, CallbackConfigBase)
1288
+ )
1289
+ yield self.log_epoch
1290
+ yield self.log_norms
1291
+ yield self.hf_hub
1292
+ yield self.shared_parameters
1293
+ yield self.reduce_lr_on_plateau_sanity_checking
1294
+ yield self.auto_set_debug_flag
1295
+ yield self.auto_validate_metrics
1296
+ yield from self.callbacks
1297
+
1298
+ def _nshtrainer_all_logger_configs(self) -> Iterable[LoggerConfigBase | None]:
1299
+ # Disable all loggers if barebones mode is enabled
1300
+ if self.barebones:
1301
+ return
1302
+
1303
+ yield from self.enabled_loggers()
1304
+ yield self.actsave_logger
1305
+
1306
+ def _nshtrainer_validate_before_run(self):
1307
+ # shared_parameters is not supported under barebones mode
1308
+ if self.barebones and self.shared_parameters:
1309
+ raise ValueError("shared_parameters is not supported under barebones mode")
1310
+
1311
+ def _nshtrainer_set_id_if_missing(self):
1312
+ """
1313
+ Set the ID for the configuration object if it is missing.
1314
+ """
1315
+ if self.id is C.MISSING:
1316
+ self.id = self.generate_id()
1317
+ log.info(f"TrainerConfig's run ID is missing, setting to {self.id}.")
1318
+ else:
1319
+ log.debug(f"TrainerConfig's run ID is already set to {self.id}.")
1320
+
1321
+ # endregion
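Net effect of the TrainerConfig changes above: the run ID is now an `AllowMissing` field that is filled in lazily (see the `_nshtrainer_set_id_if_missing` call added in trainer.py below), and a fluent, copy-on-write helper API is added. A usage sketch, assuming the remaining TrainerConfig fields have defaults; the name, project, and tag values are made up:

    from nshtrainer import TrainerConfig

    base = TrainerConfig()  # id stays MISSING until the Trainer (or the user) sets it
    cfg = (
        base.with_name("baseline", "fp32")
            .with_project("my-project")
            .with_added_tags("debug")
            .with_debug()
    )
    pinned = cfg.with_id(TrainerConfig.generate_id())  # explicitly pinning an ID still works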
nshtrainer/trainer/signal_connector.py CHANGED
@@ -14,7 +14,6 @@ from pathlib import Path
14
14
  from types import FrameType
15
15
  from typing import Any
16
16
 
17
- import nshrunner as nr
18
17
  import torch.utils.data
19
18
  from lightning.fabric.plugins.environments.lsf import LSFEnvironment
20
19
  from lightning.fabric.plugins.environments.slurm import SLURMEnvironment
@@ -34,6 +33,12 @@ _IS_WINDOWS = platform.system() == "Windows"
34
33
 
35
34
 
36
35
  def _resolve_requeue_signals():
36
+ try:
37
+ import nshrunner as nr
38
+ except ImportError:
39
+ log.debug("nshrunner not found. Skipping signal requeueing.")
40
+ return None
41
+
37
42
  if (session := nr.Session.from_current_session()) is None:
38
43
  return None
39
44
 
@@ -52,9 +57,9 @@ class _SignalConnector(_LightningSignalConnector):
52
57
 
53
58
  signals_set = set(signals)
54
59
  valid_signals: set[signal.Signals] = signal.valid_signals()
55
- assert signals_set.issubset(
56
- valid_signals
57
- ), f"Invalid signal(s) found: {signals_set - valid_signals}"
60
+ assert signals_set.issubset(valid_signals), (
61
+ f"Invalid signal(s) found: {signals_set - valid_signals}"
62
+ )
58
63
  return signals
59
64
 
60
65
  def _compose_and_register(
@@ -241,9 +246,9 @@ class _SignalConnector(_LightningSignalConnector):
241
246
  "Writing requeue script to exit script directory."
242
247
  )
243
248
  exit_script_dir = Path(exit_script_dir)
244
- assert (
245
- exit_script_dir.is_dir()
246
- ), f"Exit script directory {exit_script_dir} does not exist"
249
+ assert exit_script_dir.is_dir(), (
250
+ f"Exit script directory {exit_script_dir} does not exist"
251
+ )
247
252
 
248
253
  exit_script_path = exit_script_dir / f"requeue_{job_id}.sh"
249
254
  log.info(f"Writing requeue script to {exit_script_path}")
nshtrainer/trainer/trainer.py CHANGED
@@ -316,6 +316,7 @@ class Trainer(LightningTrainer):
316
316
  f"Trainer hparams must either be an instance of {hparams_cls} or a mapping. "
317
317
  f"Got {type(hparams)=} instead."
318
318
  )
319
+ hparams._nshtrainer_set_id_if_missing()
319
320
  hparams = hparams.model_deep_validate()
320
321
  hparams._nshtrainer_validate_before_run()
321
322
 
nshtrainer/util/_environment_info.py CHANGED
@@ -356,12 +356,20 @@ class EnvironmentSnapshotConfig(C.Config):
356
356
 
357
357
  @classmethod
358
358
  def from_current_environment(cls):
359
- draft = cls.draft()
360
- if snapshot_dir := os.environ.get("NSHRUNNER_SNAPSHOT_DIR"):
361
- draft.snapshot_dir = Path(snapshot_dir)
362
- if modules := os.environ.get("NSHRUNNER_SNAPSHOT_MODULES"):
363
- draft.modules = modules.split(",")
364
- return draft.finalize()
359
+ try:
360
+ import nshrunner as nr
361
+
362
+ if (session := nr.Session.from_current_session()) is None:
363
+ log.warning("No active session found, skipping snapshot information")
364
+ return cls.empty()
365
+
366
+ draft = cls.draft()
367
+ draft.snapshot_dir = session.snapshot_dir
368
+ draft.modules = session.snapshot_modules
369
+ return draft.finalize()
370
+ except ImportError:
371
+ log.warning("nshrunner not found, skipping snapshot information")
372
+ return cls.empty()
365
373
 
366
374
 
367
375
  class EnvironmentPackageConfig(C.Config):
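The snapshot metadata is now taken from the active nshrunner Session rather than environment variables, and it degrades gracefully when nshrunner is absent. A small sketch of the new behaviour, assuming the class is imported from its defining module:

    from nshtrainer.util._environment_info import EnvironmentSnapshotConfig

    snap = EnvironmentSnapshotConfig.from_current_environment()
    # Without nshrunner installed, or outside a session, this now returns cls.empty()
    # instead of reading NSHRUNNER_SNAPSHOT_DIR / NSHRUNNER_SNAPSHOT_MODULES.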
nshtrainer/util/code_upload.py ADDED
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+
6
+ log = logging.getLogger(__name__)
7
+
8
+
9
+ def get_code_dir() -> Path | None:
10
+ try:
11
+ import nshrunner as nr
12
+
13
+ if (session := nr.Session.from_current_session()) is None:
14
+ log.debug("No active session found. Skipping code upload.")
15
+ return None
16
+
17
+ # New versions of nshrunner will have the code_dir attribute
18
+ # in the session object. We should use that. Otherwise, use snapshot_dir.
19
+ try:
20
+ code_dir = session.code_dir # type: ignore
21
+ except AttributeError:
22
+ code_dir = session.snapshot_dir
23
+
24
+ if code_dir is None:
25
+ log.debug("No code directory found. Skipping code upload.")
26
+ return None
27
+
28
+ assert isinstance(code_dir, Path), (
29
+ f"Code directory should be a Path object. Got {type(code_dir)} instead."
30
+ )
31
+ if not code_dir.exists() or not code_dir.is_dir():
32
+ log.warning(
33
+ f"Code directory '{code_dir}' does not exist or is not a directory."
34
+ )
35
+ return None
36
+
37
+ return code_dir
38
+ except ImportError:
39
+ log.debug("nshrunner not found. Skipping code upload.")
40
+ return None
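This new helper is what the HF Hub and W&B code-upload callbacks above now call instead of reading SNAPSHOT_DIR directly. A minimal consumer sketch (illustrative only):

    from nshtrainer.util.code_upload import get_code_dir

    if (code_dir := get_code_dir()) is not None:
        print(f"would upload code from {code_dir.absolute()}")  # e.g. hand this path to an uploader
    else:
        print("no nshrunner session or snapshot; nothing to upload")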
nshtrainer-1.3.4.dist-info/METADATA → nshtrainer-1.3.6.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.3.4
3
+ Version: 1.3.6
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -15,7 +15,7 @@ Requires-Dist: GitPython ; extra == "extra"
15
15
  Requires-Dist: huggingface-hub ; extra == "extra"
16
16
  Requires-Dist: lightning
17
17
  Requires-Dist: nshconfig (>0.39)
18
- Requires-Dist: nshrunner
18
+ Requires-Dist: nshrunner ; extra == "extra"
19
19
  Requires-Dist: nshutils ; extra == "extra"
20
20
  Requires-Dist: numpy
21
21
  Requires-Dist: packaging
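The dependency change above means nshrunner is no longer a hard requirement of the core package; it is only pulled in through the optional "extra" group (for example via `pip install "nshtrainer[extra]"`), which matches the new try/except ImportError guards around nshrunner elsewhere in this diff.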
nshtrainer-1.3.4.dist-info/RECORD → nshtrainer-1.3.6.dist-info/RECORD CHANGED
@@ -1,11 +1,10 @@
1
1
  nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
2
- nshtrainer/__init__.py,sha256=VcqBfL8RgCcZDaY645nxeDmOspqerx4x46wggCMnS0E,692
2
+ nshtrainer/__init__.py,sha256=RI_2B_IUWa10B6H5TAuWtE5FWX1X4ue-J4dTDaF2-lQ,1035
3
3
  nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2uLzQE,5597
5
5
  nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
6
- nshtrainer/_directory.py,sha256=RAG8e0y3VZwGIyy_D-GXgDMK5OvitQU6qEWxHTpWEeY,2490
7
6
  nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
8
- nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
7
+ nshtrainer/_hf_hub.py,sha256=kfN0wDxK5JWKKGZnX_706i0KXGhaS19p581LDTPxlRE,13996
9
8
  nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
10
9
  nshtrainer/callbacks/actsave.py,sha256=NSXIIu62MNYe5gz479SMW33bdoKYoYtWtd_iTWFpKpc,3881
11
10
  nshtrainer/callbacks/base.py,sha256=K9aom1WVVRYxl-tHWgtmDUQZ1o63NgznvLsjauTKcCc,4225
@@ -30,13 +29,12 @@ nshtrainer/callbacks/print_table.py,sha256=VaS4JgI963do79laXK4lUkFQx8v6aRSy22W0z
30
29
  nshtrainer/callbacks/rlp_sanity_checks.py,sha256=Df9Prq2QKXnaeMBIvMQBhDhJTDeru5UbiuXJOJR16Gk,10050
31
30
  nshtrainer/callbacks/shared_parameters.py,sha256=s94jJTAIbDGukYJu6l247QonVOCudGClU4t5kLt8XrY,3076
32
31
  nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU,4731
33
- nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
32
+ nshtrainer/callbacks/wandb_upload_code.py,sha256=4X-mpiX5ghj9vnEreK2i8Xyvimqt0K-PNWA2HtT-B6I,1940
34
33
  nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
35
34
  nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
36
- nshtrainer/configs/__init__.py,sha256=KD3uClMwnA4LfQ7rY5phDdUbp3j8NoZfaGbGPbpaJVs,15848
35
+ nshtrainer/configs/__init__.py,sha256=-yJ5Uk9VkANqfk-QnX2aynL0jSf7cJQuQNzT1GAE1x8,15684
37
36
  nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
38
37
  nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
39
- nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
40
38
  nshtrainer/configs/_hf_hub/__init__.py,sha256=ciFLbV-JV8SVzqo2SyythEuDMnk7gGfdIacB18QYnkY,511
41
39
  nshtrainer/configs/callbacks/__init__.py,sha256=tP9urR73NIanyxpbi4EERsxOnGNiptbQpmsj-v53a38,4774
42
40
  nshtrainer/configs/callbacks/actsave/__init__.py,sha256=JvjSZtEoA28FC4u-QT3skQzBDVbN9eq07rn4u2ydW-E,377
@@ -85,8 +83,8 @@ nshtrainer/configs/profiler/_base/__init__.py,sha256=ekYfPg-VDhCAFM5nJka2TxUYdRD
85
83
  nshtrainer/configs/profiler/advanced/__init__.py,sha256=-ThpUat16Ij_0avkMUVVA8wCWDG_q_tM7KQofnWQCtg,308
86
84
  nshtrainer/configs/profiler/pytorch/__init__.py,sha256=soAU1s2_Pa1na4gW8CK-iysJBO5M_7YeZC2_x40iEdg,294
87
85
  nshtrainer/configs/profiler/simple/__init__.py,sha256=3Wb11lPuFuyasq8xS1CZ4WLuBCLS_nVSQGVllvOOi0Y,289
88
- nshtrainer/configs/trainer/__init__.py,sha256=YLlDOUYDp_qURHhcmhCxTcY6K5AbmoTxdzBPB9SEZII,8040
89
- nshtrainer/configs/trainer/_config/__init__.py,sha256=6DXdtP-uH11TopQ7kzId9fco-wVkD7ZfevbBqDpN6TE,3817
86
+ nshtrainer/configs/trainer/__init__.py,sha256=DM2PlB4WRDZ_dqEeW91LbKRFa4sIF_pETU0T9GYJ5-g,8073
87
+ nshtrainer/configs/trainer/_config/__init__.py,sha256=z5UpuXktBanLOYNkkbgbbHE06iQtcSuAKTpnx2TLmCo,3850
90
88
  nshtrainer/configs/trainer/accelerator/__init__.py,sha256=3H6R3wlwbKL1TzDqGCChZk78-BcE2czLouo7Djiq3nA,898
91
89
  nshtrainer/configs/trainer/plugin/__init__.py,sha256=NkHQxMPkrtTtdIAO4dQUE9SWEcHRDB0yUXLkTjnl4dA,3332
92
90
  nshtrainer/configs/trainer/plugin/base/__init__.py,sha256=slW5z1FZw2qICXO9l9DnLIDB1Yl7KOcxPEZkyYIHrp4,276
@@ -135,7 +133,7 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
135
133
  nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
136
134
  nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
137
135
  nshtrainer/trainer/__init__.py,sha256=jRaHdaFK8wxNrN1bleT9cf29iZahL_-XkWo5TWz2CmA,550
138
- nshtrainer/trainer/_config.py,sha256=SohR7uxANnP3xrrcW_mAjk6TuDamsW5Qdk3dlnPinDw,33457
136
+ nshtrainer/trainer/_config.py,sha256=x3YjP_0IykqSRh8YIilCxq3nPt_fZXoVcxfR13ulmV0,45578
139
137
  nshtrainer/trainer/_distributed_prediction_result.py,sha256=bQw8Z6PT694UUf-zQPkech6CxyUSy8bAIexfSfPej0U,2507
140
138
  nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
141
139
  nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
@@ -146,11 +144,12 @@ nshtrainer/trainer/plugin/environment.py,sha256=SSXRWHjyFUA6oFx3duD_ZwhM59pWUjR1
146
144
  nshtrainer/trainer/plugin/io.py,sha256=OmFSKLloMypletjaUr_Ptg6LS0ljqTVIp2o4Hm3eZoE,1926
147
145
  nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLcOfPXnvH29s,663
148
146
  nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
149
- nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
147
+ nshtrainer/trainer/signal_connector.py,sha256=ZgbSkbthoe8MYN6rBoFf-7UDpQtc9fs9pG_FNvTYSfs,10962
150
148
  nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
151
- nshtrainer/trainer/trainer.py,sha256=6oky6E8cjGqUNzJGyyTO551pE9A6YueOv5oxg1fZVR0,24129
152
- nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
149
+ nshtrainer/trainer/trainer.py,sha256=iQWu0KfwLY-1q9EEsg0xPlyUN1fsJ9iXfSQbPmiMlac,24177
150
+ nshtrainer/util/_environment_info.py,sha256=j-wyEHKirsu3rIXTtqC2kLmIIkRe6obWjxPVWaqg2ow,24887
153
151
  nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
152
+ nshtrainer/util/code_upload.py,sha256=CpbZEBbA8EcBElUVoCPbP5zdwtNzJhS20RLaOB-q-2k,1257
154
153
  nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
155
154
  nshtrainer/util/config/dtype.py,sha256=Fn_MhhQoHPyFAnFPSwvcvLiGR3yWFIszMba02CJiC4g,2213
156
155
  nshtrainer/util/config/duration.py,sha256=mM-UfU_HvhXwW33TYEDg0x58n80tnle2e6VaWtxZTjk,764
@@ -160,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
160
159
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
161
160
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
162
161
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
163
- nshtrainer-1.3.4.dist-info/METADATA,sha256=Dm6wgfQh8ZC42IeftejuUZ-KZ2YBWBjnpHa_pYNi7Kc,960
164
- nshtrainer-1.3.4.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
165
- nshtrainer-1.3.4.dist-info/RECORD,,
162
+ nshtrainer-1.3.6.dist-info/METADATA,sha256=KidLM7J5P7mALTfYQsveSDjAOJZ-Gcq4_e1-Xgrms68,979
163
+ nshtrainer-1.3.6.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
164
+ nshtrainer-1.3.6.dist-info/RECORD,,
nshtrainer/_directory.py DELETED
@@ -1,72 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import logging
4
- from pathlib import Path
5
-
6
- import nshconfig as C
7
-
8
- from .callbacks.directory_setup import DirectorySetupCallbackConfig
9
- from .loggers import LoggerConfig
10
-
11
- log = logging.getLogger(__name__)
12
-
13
-
14
- class DirectoryConfig(C.Config):
15
- project_root: Path | None = None
16
- """
17
- Root directory for this project.
18
-
19
- This isn't specific to the run; it is the parent directory of all runs.
20
- """
21
-
22
- logdir_basename: str = "nshtrainer"
23
- """Base name for the log directory."""
24
-
25
- setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
26
- """Configuration for the directory setup PyTorch Lightning callback."""
27
-
28
- def resolve_run_root_directory(self, run_id: str) -> Path:
29
- if (project_root_dir := self.project_root) is None:
30
- project_root_dir = Path.cwd()
31
-
32
- # The default base dir is $CWD/{logdir_basename}/{id}/
33
- base_dir = project_root_dir / self.logdir_basename
34
- base_dir.mkdir(exist_ok=True)
35
-
36
- # Add a .gitignore file to the {logdir_basename} directory
37
- # which will ignore all files except for the .gitignore file itself
38
- gitignore_path = base_dir / ".gitignore"
39
- if not gitignore_path.exists():
40
- gitignore_path.touch()
41
- gitignore_path.write_text("*\n")
42
-
43
- base_dir = base_dir / run_id
44
- base_dir.mkdir(exist_ok=True)
45
-
46
- return base_dir
47
-
48
- def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
49
- # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
50
- if (subdir := getattr(self, subdirectory, None)) is not None:
51
- assert isinstance(subdir, Path), (
52
- f"Expected a Path for {subdirectory}, got {type(subdir)}"
53
- )
54
- return subdir
55
-
56
- dir = self.resolve_run_root_directory(run_id)
57
- dir = dir / subdirectory
58
- dir.mkdir(exist_ok=True)
59
- return dir
60
-
61
- def _resolve_log_directory_for_logger(self, run_id: str, logger: LoggerConfig):
62
- if (log_dir := logger.log_dir) is not None:
63
- return log_dir
64
-
65
- # Save to {logdir_basename}/{id}/log/{logger name}
66
- log_dir = self.resolve_subdirectory(run_id, "log")
67
- log_dir = log_dir / logger.resolve_logger_dirname()
68
- # ^ NOTE: Logger must have a `name` attribute, as this is
69
- # the discriminator for the logger registry
70
- log_dir.mkdir(exist_ok=True)
71
-
72
- return log_dir
nshtrainer/configs/_directory/__init__.py DELETED
@@ -1,15 +0,0 @@
1
- from __future__ import annotations
2
-
3
- __codegen__ = True
4
-
5
- from nshtrainer._directory import DirectoryConfig as DirectoryConfig
6
- from nshtrainer._directory import (
7
- DirectorySetupCallbackConfig as DirectorySetupCallbackConfig,
8
- )
9
- from nshtrainer._directory import LoggerConfig as LoggerConfig
10
-
11
- __all__ = [
12
- "DirectoryConfig",
13
- "DirectorySetupCallbackConfig",
14
- "LoggerConfig",
15
- ]
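These two deletions are the other half of the relocation seen earlier: DirectoryConfig (and its generated config re-exports) now lives under the trainer config module. A sketch of the corresponding import change for downstream code:

    # 1.3.4
    from nshtrainer._directory import DirectoryConfig
    # 1.3.6
    from nshtrainer.trainer._config import DirectoryConfig  # also still re-exported from nshtrainer.configs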