nshtrainer 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_directory.py CHANGED
@@ -19,20 +19,8 @@ class DirectoryConfig(C.Config):
19
19
  This isn't specific to the run; it is the parent directory of all runs.
20
20
  """
21
21
 
22
- log: Path | None = None
23
- """Base directory for all experiment tracking (e.g., WandB, Tensorboard, etc.) files. If None, will use nshtrainer/{id}/log/."""
24
-
25
- stdio: Path | None = None
26
- """stdout/stderr log directory to use for the trainer. If None, will use nshtrainer/{id}/stdio/."""
27
-
28
- checkpoint: Path | None = None
29
- """Checkpoint directory to use for the trainer. If None, will use nshtrainer/{id}/checkpoint/."""
30
-
31
- activation: Path | None = None
32
- """Activation directory to use for the trainer. If None, will use nshtrainer/{id}/activation/."""
33
-
34
- profile: Path | None = None
35
- """Directory to save profiling information to. If None, will use nshtrainer/{id}/profile/."""
22
+ logdir_basename: str = "nshtrainer"
23
+ """Base name for the log directory."""
36
24
 
37
25
  setup_callback: DirectorySetupCallbackConfig = DirectorySetupCallbackConfig()
38
26
  """Configuration for the directory setup PyTorch Lightning callback."""
@@ -41,11 +29,11 @@ class DirectoryConfig(C.Config):
41
29
  if (project_root_dir := self.project_root) is None:
42
30
  project_root_dir = Path.cwd()
43
31
 
44
- # The default base dir is $CWD/nshtrainer/{id}/
45
- base_dir = project_root_dir / "nshtrainer"
32
+ # The default base dir is $CWD/{logdir_basename}/{id}/
33
+ base_dir = project_root_dir / self.logdir_basename
46
34
  base_dir.mkdir(exist_ok=True)
47
35
 
48
- # Add a .gitignore file to the nshtrainer directory
36
+ # Add a .gitignore file to the {logdir_basename} directory
49
37
  # which will ignore all files except for the .gitignore file itself
50
38
  gitignore_path = base_dir / ".gitignore"
51
39
  if not gitignore_path.exists():
@@ -57,13 +45,8 @@ class DirectoryConfig(C.Config):
57
45
 
58
46
  return base_dir
59
47
 
60
- def resolve_subdirectory(
61
- self,
62
- run_id: str,
63
- # subdirectory: Literal["log", "stdio", "checkpoint", "activation", "profile"],
64
- subdirectory: str,
65
- ) -> Path:
66
- # The subdir will be $CWD/nshtrainer/{id}/{log, stdio, checkpoint, activation}/
48
+ def resolve_subdirectory(self, run_id: str, subdirectory: str) -> Path:
49
+ # The subdir will be $CWD/{logdir_basename}/{id}/{log, stdio, checkpoint, activation}/
67
50
  if (subdir := getattr(self, subdirectory, None)) is not None:
68
51
  assert isinstance(subdir, Path), (
69
52
  f"Expected a Path for {subdirectory}, got {type(subdir)}"
@@ -79,7 +62,7 @@ class DirectoryConfig(C.Config):
79
62
  if (log_dir := logger.log_dir) is not None:
80
63
  return log_dir
81
64
 
82
- # Save to nshtrainer/{id}/log/{logger name}
65
+ # Save to {logdir_basename}/{id}/log/{logger name}
83
66
  log_dir = self.resolve_subdirectory(run_id, "log")
84
67
  log_dir = log_dir / logger.resolve_logger_dirname()
85
68
  # ^ NOTE: Logger must have a `name` attribute, as this is
@@ -1,7 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- import os
5
4
  from pathlib import Path
6
5
  from typing import Literal
7
6
 
@@ -27,6 +26,7 @@ class DirectorySetupCallbackConfig(CallbackConfigBase):
27
26
  def __bool__(self):
28
27
  return self.enabled
29
28
 
29
+ @override
30
30
  def create_callbacks(self, trainer_config):
31
31
  if not self:
32
32
  return
@@ -35,21 +35,28 @@ class DirectorySetupCallbackConfig(CallbackConfigBase):
35
35
 
36
36
 
37
37
  def _create_symlink_to_nshrunner(base_dir: Path):
38
- # Resolve the current nshrunner session directory
39
- if not (session_dir := os.environ.get("NSHRUNNER_SESSION_DIR")):
40
- log.warning("NSHRUNNER_SESSION_DIR is not set. Skipping symlink creation.")
38
+ try:
39
+ import nshrunner as nr
40
+ except ImportError:
41
+ log.info("nshrunner is not installed. Skipping symlink creation to nshrunner.")
42
+ return
43
+
44
+ # Check if we are in a nshrunner session
45
+ if (session := nr.Session.from_current_session()) is None:
46
+ log.info("No current nshrunner session found. Skipping symlink creation.")
41
47
  return
42
- session_dir = Path(session_dir)
48
+
49
+ session_dir = session.session_dir
43
50
  if not session_dir.exists() or not session_dir.is_dir():
44
51
  log.warning(
45
- f"NSHRUNNER_SESSION_DIR is not a valid directory: {session_dir}. "
52
+ f"nshrunner's session_dir is not a valid directory: {session_dir}. "
46
53
  "Skipping symlink creation."
47
54
  )
48
55
  return
49
56
 
50
57
  # Create the symlink
51
58
  symlink_path = base_dir / "nshrunner"
52
- if symlink_path.exists():
59
+ if symlink_path.exists(follow_symlinks=False):
53
60
  # If it already points to the correct directory, we're done
54
61
  if symlink_path.resolve() == session_dir.resolve():
55
62
  return
@@ -61,7 +68,7 @@ def _create_symlink_to_nshrunner(base_dir: Path):
61
68
  )
62
69
  symlink_path.unlink()
63
70
 
64
- symlink_path.symlink_to(session_dir)
71
+ symlink_path.symlink_to(session_dir, target_is_directory=True)
65
72
 
66
73
 
67
74
  class DirectorySetupCallback(NTCallbackBase):
@@ -761,6 +761,7 @@ class TrainerConfig(C.Config):
761
761
  )
762
762
 
763
763
  def _nshtrainer_all_callback_configs(self) -> Iterable[CallbackConfigBase | None]:
764
+ yield self.directory.setup_callback
764
765
  yield self.early_stopping
765
766
  yield self.checkpoint_saving
766
767
  yield self.lr_monitor
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.2.0
3
+ Version: 1.2.1
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -3,7 +3,7 @@ nshtrainer/__init__.py,sha256=VcqBfL8RgCcZDaY645nxeDmOspqerx4x46wggCMnS0E,692
3
3
  nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2uLzQE,5597
5
5
  nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
6
- nshtrainer/_directory.py,sha256=SuXJe9xJXZkDXWWfeOS9rEDz6vZUA6mpnEdkAW0ZQnY,3193
6
+ nshtrainer/_directory.py,sha256=RAG8e0y3VZwGIyy_D-GXgDMK5OvitQU6qEWxHTpWEeY,2490
7
7
  nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
8
8
  nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
9
9
  nshtrainer/callbacks/__init__.py,sha256=m6eJuprZfBELuKpngKXre33B9yPXkG7jlKVmI-0yXRQ,4000
@@ -15,7 +15,7 @@ nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=aCs3E1eucfDlUeW2Iq_Ke7
15
15
  nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
17
17
  nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
18
- nshtrainer/callbacks/directory_setup.py,sha256=wPas_Ren8ANejogmIdKhqqgj4ulxz9AS_8xVIAfRXa0,2565
18
+ nshtrainer/callbacks/directory_setup.py,sha256=Ln6f0tCgoBscHeigIAWtCCoAmuWB-kPyaf7SylU7MYo,2773
19
19
  nshtrainer/callbacks/distributed_prediction_writer.py,sha256=OSh2C6XF7Nki4eFByNVhwlt69izkxnlmfPx54w4rvBo,5274
20
20
  nshtrainer/callbacks/early_stopping.py,sha256=rC_qYKCQWjRQJFo0ky46uG0aDJdYP8vsSlKunk0bUVI,4765
21
21
  nshtrainer/callbacks/ema.py,sha256=dBFiUXG0xmyCw8-ayuSzJMKqSbepl6Ii5VIbhFlT5ug,12255
@@ -135,7 +135,7 @@ nshtrainer/profiler/advanced.py,sha256=XrM3FX0ThCv5UwUrrH0l4Ow4LGAtpiBww2N8QAU5N
135
135
  nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1Q,2757
136
136
  nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
137
137
  nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
138
- nshtrainer/trainer/_config.py,sha256=tdWAYh-KGXBpgdY8fwvOejjRZN-AS2Ze0f_9s2VEuZ0,33556
138
+ nshtrainer/trainer/_config.py,sha256=EaAZCajuvDoVAn79zGebUgg0ijz-i3eLhwxYq8oTNe8,33600
139
139
  nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
140
140
  nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
141
141
  nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
@@ -159,6 +159,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
159
159
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
160
160
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
161
161
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
162
- nshtrainer-1.2.0.dist-info/METADATA,sha256=HkNLruaJJuf3ijnGe7NqNd9emBR6QHMRh2-taC5wTrU,960
163
- nshtrainer-1.2.0.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
164
- nshtrainer-1.2.0.dist-info/RECORD,,
162
+ nshtrainer-1.2.1.dist-info/METADATA,sha256=MrF6xvpgRjy2HsFwCIU9YZXMYcOb9ItAXY8m66P2CoQ,960
163
+ nshtrainer-1.2.1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
164
+ nshtrainer-1.2.1.dist-info/RECORD,,