nshtrainer 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/_directory.py +8 -25
- nshtrainer/callbacks/directory_setup.py +15 -8
- nshtrainer/trainer/_config.py +1 -0
- {nshtrainer-1.2.0.dist-info → nshtrainer-1.2.1.dist-info}/METADATA +1 -1
- {nshtrainer-1.2.0.dist-info → nshtrainer-1.2.1.dist-info}/RECORD +6 -6
- {nshtrainer-1.2.0.dist-info → nshtrainer-1.2.1.dist-info}/WHEEL +0 -0
nshtrainer/_directory.py
CHANGED
@@ -19,20 +19,8 @@ class DirectoryConfig(C.Config):
|
|
19
19
|
This isn't specific to the run; it is the parent directory of all runs.
|
20
20
|
"""
|
21
21
|
|
22
|
-
|
23
|
-
"""Base
|
24
|
-
|
25
|
-
stdio: Path | None = None
|
26
|
-
"""stdout/stderr log directory to use for the trainer. If None, will use nshtrainer/{id}/stdio/."""
|
27
|
-
|
28
|
-
checkpoint: Path | None = None
|
29
|
-
"""Checkpoint directory to use for the trainer. If None, will use nshtrainer/{id}/checkpoint/."""
|
30
|
-
|
31
|
-
activation: Path | None = None
|
32
|
-
"""Activation directory to use for the trainer. If None, will use nshtrainer/{id}/activation/."""
|
33
|
-
|
34
|
-
profile: Path | None = None
|
35
|
-
"""Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
|
22
|
+
logdir_basename: str = "nshtrainer"
|
23
|
+
"""Base name for the log directory."""
|
36
24
|
|
37
25
|
setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
|
38
26
|
"""Configuration for the directory setup PyTorch Lightning callback."""
|
@@ -41,11 +29,11 @@ class DirectoryConfig(C.Config):
|
|
41
29
|
if (project_root_dir := self.project_root) is None:
|
42
30
|
project_root_dir = Path.cwd()
|
43
31
|
|
44
|
-
# The default base dir is $CWD/
|
45
|
-
base_dir = project_root_dir /
|
32
|
+
# The default base dir is $CWD/{logdir_basename}/{id}/
|
33
|
+
base_dir = project_root_dir / self.logdir_basename
|
46
34
|
base_dir.mkdir(exist_ok=True)
|
47
35
|
|
48
|
-
# Add a .gitignore file to the
|
36
|
+
# Add a .gitignore file to the {logdir_basename} directory
|
49
37
|
# which will ignore all files except for the .gitignore file itself
|
50
38
|
gitignore_path = base_dir / ".gitignore"
|
51
39
|
if not gitignore_path.exists():
|
@@ -57,13 +45,8 @@ class DirectoryConfig(C.Config):
|
|
57
45
|
|
58
46
|
return base_dir
|
59
47
|
|
60
|
-
def resolve_subdirectory(
|
61
|
-
|
62
|
-
run_id: str,
|
63
|
-
# subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
|
64
|
-
subdirectory: str,
|
65
|
-
) -> Path:
|
66
|
-
# The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
|
48
|
+
def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
|
49
|
+
# The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
|
67
50
|
if (subdir := getattr(self, subdirectory, None)) is not None:
|
68
51
|
assert isinstance(subdir, Path), (
|
69
52
|
f"Expected a Path for {subdirectory}, got {type(subdir)}"
|
@@ -79,7 +62,7 @@ class DirectoryConfig(C.Config):
|
|
79
62
|
if (log_dir := logger.log_dir) is not None:
|
80
63
|
return log_dir
|
81
64
|
|
82
|
-
# Save to
|
65
|
+
# Save to {logdir_basename}/{id}/log/{logger name}
|
83
66
|
log_dir = self.resolve_subdirectory(run_id, "log")
|
84
67
|
log_dir = log_dir / logger.resolve_logger_dirname()
|
85
68
|
# ^ NOTE: Logger must have a `name` attribute, as this is
|
@@ -1,7 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import logging
|
4
|
-
import os
|
5
4
|
from pathlib import Path
|
6
5
|
from typing import Literal
|
7
6
|
|
@@ -27,6 +26,7 @@ class DirectorySetupCallbackConfig(CallbackConfigBase):
|
|
27
26
|
def __bool__(self):
|
28
27
|
return self.enabled
|
29
28
|
|
29
|
+
@override
|
30
30
|
def create_callbacks(self, trainer_config):
|
31
31
|
if not self:
|
32
32
|
return
|
@@ -35,21 +35,28 @@ class DirectorySetupCallbackConfig(CallbackConfigBase):
|
|
35
35
|
|
36
36
|
|
37
37
|
def _create_symlink_to_nshrunner(base_dir: Path):
|
38
|
-
|
39
|
-
|
40
|
-
|
38
|
+
try:
|
39
|
+
import nshrunner as nr
|
40
|
+
except ImportError:
|
41
|
+
log.info("nshrunner is not installed. Skipping symlink creation to nshrunner.")
|
42
|
+
return
|
43
|
+
|
44
|
+
# Check if we are in a nshrunner session
|
45
|
+
if (session := nr.Session.from_current_session()) is None:
|
46
|
+
log.info("No current nshrunner session found. Skipping symlink creation.")
|
41
47
|
return
|
42
|
-
|
48
|
+
|
49
|
+
session_dir = session.session_dir
|
43
50
|
if not session_dir.exists() or not session_dir.is_dir():
|
44
51
|
log.warning(
|
45
|
-
f"
|
52
|
+
f"nshrunner's session_dir is not a valid directory: {session_dir}. "
|
46
53
|
"Skipping symlink creation."
|
47
54
|
)
|
48
55
|
return
|
49
56
|
|
50
57
|
# Create the symlink
|
51
58
|
symlink_path = base_dir / "nshrunner"
|
52
|
-
if symlink_path.exists():
|
59
|
+
if symlink_path.exists(follow_symlinks=False):
|
53
60
|
# If it already points to the correct directory, we're done
|
54
61
|
if symlink_path.resolve() == session_dir.resolve():
|
55
62
|
return
|
@@ -61,7 +68,7 @@ def _create_symlink_to_nshrunner(base_dir: Path):
|
|
61
68
|
)
|
62
69
|
symlink_path.unlink()
|
63
70
|
|
64
|
-
symlink_path.symlink_to(session_dir)
|
71
|
+
symlink_path.symlink_to(session_dir, target_is_directory=True)
|
65
72
|
|
66
73
|
|
67
74
|
class DirectorySetupCallback(NTCallbackBase):
|
nshtrainer/trainer/_config.py
CHANGED
@@ -761,6 +761,7 @@ class TrainerConfig(C.Config):
|
|
761
761
|
)
|
762
762
|
|
763
763
|
def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
|
764
|
+
yield self.directory.setup_callback
|
764
765
|
yield self.early_stopping
|
765
766
|
yield self.checkpoint_saving
|
766
767
|
yield self.lr_monitor
|
@@ -3,7 +3,7 @@ nshtrainer/__init__.py,sha256=VcqBfL8RgCcZDaY645nxeDmOspqerx4x46wggCMnS0E,692
|
|
3
3
|
nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
|
4
4
|
nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2uLzQE,5597
|
5
5
|
nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
|
6
|
-
nshtrainer/_directory.py,sha256=
|
6
|
+
nshtrainer/_directory.py,sha256=RAG8e0y3VZwGIyy_D-GXgDMK5OvitQU6qEWxHTpWEeY,2490
|
7
7
|
nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
|
8
8
|
nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
|
9
9
|
nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
|
@@ -15,7 +15,7 @@ nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=aCs3E1eucfDlUeW2Iq_Ke7
|
|
15
15
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
|
16
16
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
|
17
17
|
nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
|
18
|
-
nshtrainer/callbacks/directory_setup.py,sha256=
|
18
|
+
nshtrainer/callbacks/directory_setup.py,sha256=Ln6f0tCgoBscHeigIAWtCCoAmuWB-kPyaf7SylU7MYo,2773
|
19
19
|
nshtrainer/callbacks/distributed_prediction_writer.py,sha256=OSh2C6XF7Nki4eFByNVhwlt69izkxnlmfPx54w4rvBo,5274
|
20
20
|
nshtrainer/callbacks/early_stopping.py,sha256=rC_qYKCQWjRQJFo0ky46uG0aDJdYP8vsSlKunk0bUVI,4765
|
21
21
|
nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,12255
|
@@ -135,7 +135,7 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
|
|
135
135
|
nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
|
136
136
|
nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
|
137
137
|
nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
|
138
|
-
nshtrainer/trainer/_config.py,sha256=
|
138
|
+
nshtrainer/trainer/_config.py,sha256=EaAZCajuvDoVAn79zGebUgg0ijz-i3eLhwxYq8oTNe8,33600
|
139
139
|
nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
|
140
140
|
nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
|
141
141
|
nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
|
@@ -159,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
|
159
159
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
160
160
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
161
161
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
162
|
-
nshtrainer-1.2.
|
163
|
-
nshtrainer-1.2.
|
164
|
-
nshtrainer-1.2.
|
162
|
+
nshtrainer-1.2.1.dist-info/METADATA,sha256=MrF6xvpgRjy2HsFwCIU9YZXMYcOb9ItAXY8m66P2CoQ,960
|
163
|
+
nshtrainer-1.2.1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
|
164
|
+
nshtrainer-1.2.1.dist-info/RECORD,,
|
File without changes
|