nshtrainer 1.0.0b37__py3-none-any.whl → 1.0.0b39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/_directory.py CHANGED
@@ -81,7 +81,7 @@ class DirectoryConfig(C.Config):
81
81
 
82
82
  # Save to nshtrainer/{id}/log/{logger name}
83
83
  log_dir = self.resolve_subdirectory(run_id, "log")
84
- log_dir = log_dir / getattr(logger, "name")
84
+ log_dir = log_dir / logger.resolve_logger_dirname()
85
85
  # ^ NOTE: Logger must have a `name` attribute, as this is
86
86
  # the discriminator for the logger registry
87
87
  log_dir.mkdir(exist_ok=True)
@@ -21,11 +21,8 @@ log = logging.getLogger(__name__)
21
21
  class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
22
22
  name: Literal["last_checkpoint"] = "last_checkpoint"
23
23
 
24
- save_on_time_interval: bool = True
25
- """Whether to save checkpoints based on time interval."""
26
-
27
- interval: timedelta = timedelta(hours=12)
28
- """Time interval between checkpoints when save_on_time_interval is True."""
24
+ save_on_time_interval: timedelta | None = None
25
+ """Save a checkpoint whenever this much time (a `timedelta`, not a number of seconds) has elapsed since the last checkpoint. If `None`, this feature is disabled."""
29
26
 
30
27
  @override
31
28
  def create_checkpoint(self, trainer_config, dirpath):
@@ -38,8 +35,6 @@ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
38
35
  super().__init__(config, dirpath)
39
36
  self.start_time = time.time()
40
37
  self.last_checkpoint_time = self.start_time
41
- self.interval_seconds = config.interval.total_seconds()
42
- self.save_on_time_interval = config.save_on_time_interval
43
38
 
44
39
  @override
45
40
  def name(self):
@@ -57,12 +52,18 @@ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
57
52
  def topk_sort_reverse(self):
58
53
  return True
59
54
 
60
- def _should_checkpoint(self) -> bool:
61
- if not self.save_on_time_interval:
55
+ def _local_should_checkpoint(self) -> bool:
56
+ if (interval := self.config.save_on_time_interval) is None:
62
57
  return False
58
+
63
59
  current_time = time.time()
64
60
  elapsed_time = current_time - self.last_checkpoint_time
65
- return elapsed_time >= self.interval_seconds
61
+ return elapsed_time >= interval.total_seconds()
62
+
63
+ def _should_checkpoint(self, trainer: Trainer):
64
+ if self.config.save_on_time_interval is None:
65
+ return False
66
+ return trainer.strategy.broadcast(self._local_should_checkpoint(), src=0)
66
67
 
67
68
  def _format_duration(self, seconds: float) -> str:
68
69
  """Format duration in seconds to a human-readable string."""
@@ -98,7 +99,7 @@ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
98
99
  *args,
99
100
  **kwargs,
100
101
  ):
101
- if not self._should_checkpoint():
102
+ if not self._should_checkpoint(trainer):
102
103
  return
103
104
  self.save_checkpoints(trainer)
104
105
 
@@ -110,5 +111,5 @@ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
110
111
  def save_checkpoints(self, trainer):
111
112
  super().save_checkpoints(trainer)
112
113
 
113
- if self.save_on_time_interval:
114
+ if self.config.save_on_time_interval is not None:
114
115
  self.last_checkpoint_time = time.time()
@@ -30,5 +30,14 @@ class LoggerConfigBase(C.Config, ABC):
30
30
  def __bool__(self):
31
31
  return self.enabled
32
32
 
33
+ def resolve_logger_dirname(self) -> str:
34
+ if not (name := getattr(self, "name", None)):
35
+ raise ValueError(
36
+ "Logger must have a name attribute to resolve the directory name.\n"
37
+ "Otherwise, you must override `resolve_logger_dirname`."
38
+ )
39
+
40
+ return name
41
+
33
42
 
34
43
  logger_registry = C.Registry(LoggerConfigBase, discriminator="name")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 1.0.0b37
3
+ Version: 1.0.0b39
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -3,7 +3,7 @@ nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
3
3
  nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=PHy-54Cg-o3OtCffAqrVv6ZVMU7zhRo_-sZiSEEno1Y,5019
5
5
  nshtrainer/_checkpoint/saver.py,sha256=LOP8jjKF0Dw9x9H-BKrLMWlEp1XTan2DUK0zQUCWw5U,1360
6
- nshtrainer/_directory.py,sha256=xY8Z9POZJw0Uh56yqffZbnNZvdA_tnWCucT31dhwFCM,3183
6
+ nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
7
7
  nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
8
8
  nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
9
9
  nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTwXJI,3577
@@ -12,7 +12,7 @@ nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,
12
12
  nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
13
13
  nshtrainer/callbacks/checkpoint/_base.py,sha256=ZVEUVl5kjCSSe69Q0rMUbKBNNUog0pxBwWkeyuxG2w0,6304
14
14
  nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=2CQuhPJ3Fi7lDw7z-J8kXXXuDU8-4HcU48oZxR49apk,2667
15
- nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=MJcNB0biOebx2si2IBFaSUiVOSLSCZTzxB-RcEgO2gY,3482
15
+ nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
16
16
  nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
17
17
  nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
18
18
  nshtrainer/callbacks/directory_setup.py,sha256=wPas_Ren8ANejogmIdKhqqgj4ulxz9AS_8xVIAfRXa0,2565
@@ -101,7 +101,7 @@ nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk
101
101
  nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
102
102
  nshtrainer/loggers/__init__.py,sha256=Ddd3JJXVzew_ZpwHA9kGnGmvq4OwhItwghDL5PzNhDc,614
103
103
  nshtrainer/loggers/actsave.py,sha256=wgNrpBB6wQM7qff8iLDb_sQnbiAcYHRmH56pcEJPB3o,1409
104
- nshtrainer/loggers/base.py,sha256=1-HoPmOiyXevQvMLXboiKe-4GOE1V5SvjURohOHakVc,882
104
+ nshtrainer/loggers/base.py,sha256=ON92XbwTSgadQOSyw5PiRRFzyH6uJ-xLtE0nB3cbgPc,1205
105
105
  nshtrainer/loggers/csv.py,sha256=xJ8mSmw4vJwinIfqhF6t2HWmh_1dXEYyLfGuXwL7WHo,1160
106
106
  nshtrainer/loggers/tensorboard.py,sha256=E7iO_fDt9bfH02hBL430bXPLljOo5iGgq2QyPqmx2gQ,2324
107
107
  nshtrainer/loggers/wandb.py,sha256=KZXAUWrrmdX_L8rqej77oUHaM0JxZRM8y9z6JP9PISw,6856
@@ -151,6 +151,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
151
151
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
152
152
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
153
153
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
154
- nshtrainer-1.0.0b37.dist-info/METADATA,sha256=ObMgpZ_qJLmBAkeRDN7ufTuRSTltiB_LYPFTphNvWks,988
155
- nshtrainer-1.0.0b37.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
156
- nshtrainer-1.0.0b37.dist-info/RECORD,,
154
+ nshtrainer-1.0.0b39.dist-info/METADATA,sha256=zzE6nHlj-clB3HJs5_-bBePCHSOrtTkZTi9z_NrSeRY,988
155
+ nshtrainer-1.0.0b39.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
156
+ nshtrainer-1.0.0b39.dist-info/RECORD,,