nshtrainer 1.0.0b37__py3-none-any.whl → 1.0.0b39__py3-none-any.whl
This diff shows the changes between two publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
- nshtrainer/_directory.py +1 -1
- nshtrainer/callbacks/checkpoint/last_checkpoint.py +13 -12
- nshtrainer/loggers/base.py +9 -0
- {nshtrainer-1.0.0b37.dist-info → nshtrainer-1.0.0b39.dist-info}/METADATA +1 -1
- {nshtrainer-1.0.0b37.dist-info → nshtrainer-1.0.0b39.dist-info}/RECORD +6 -6
- {nshtrainer-1.0.0b37.dist-info → nshtrainer-1.0.0b39.dist-info}/WHEEL +0 -0
nshtrainer/_directory.py
CHANGED
@@ -81,7 +81,7 @@ class DirectoryConfig(C.Config):
|
|
81
81
|
|
82
82
|
# Save to nshtrainer/{id}/log/{logger name}
|
83
83
|
log_dir = self.resolve_subdirectory(run_id, "log")
|
84
|
-
log_dir = log_dir /
|
84
|
+
log_dir = log_dir / logger.resolve_logger_dirname()
|
85
85
|
# ^ NOTE: Logger must have a `name` attribute, as this is
|
86
86
|
# the discriminator for the logger registry
|
87
87
|
log_dir.mkdir(exist_ok=True)
|
nshtrainer/callbacks/checkpoint/last_checkpoint.py
CHANGED
@@ -21,11 +21,8 @@ log = logging.getLogger(__name__)
|
|
21
21
|
class LastCheckpointCallbackConfig(BaseCheckpointCallbackConfig):
|
22
22
|
name: Literal["last_checkpoint"] = "last_checkpoint"
|
23
23
|
|
24
|
-
save_on_time_interval:
|
25
|
-
"""
|
26
|
-
|
27
|
-
interval: timedelta = timedelta(hours=12)
|
28
|
-
"""Time interval between checkpoints when save_on_time_interval is True."""
|
24
|
+
save_on_time_interval: timedelta | None = None
|
25
|
+
"""Save a checkpoint whenever at least `save_on_time_interval` (a `timedelta`) has elapsed since the last checkpoint. If `None`, this feature is disabled."""
|
29
26
|
|
30
27
|
@override
|
31
28
|
def create_checkpoint(self, trainer_config, dirpath):
|
@@ -38,8 +35,6 @@ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
|
|
38
35
|
super().__init__(config, dirpath)
|
39
36
|
self.start_time = time.time()
|
40
37
|
self.last_checkpoint_time = self.start_time
|
41
|
-
self.interval_seconds = config.interval.total_seconds()
|
42
|
-
self.save_on_time_interval = config.save_on_time_interval
|
43
38
|
|
44
39
|
@override
|
45
40
|
def name(self):
|
@@ -57,12 +52,18 @@ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
|
|
57
52
|
def topk_sort_reverse(self):
|
58
53
|
return True
|
59
54
|
|
60
|
-
def
|
61
|
-
if
|
55
|
+
def _local_should_checkpoint(self) -> bool:
|
56
|
+
if (interval := self.config.save_on_time_interval) is None:
|
62
57
|
return False
|
58
|
+
|
63
59
|
current_time = time.time()
|
64
60
|
elapsed_time = current_time - self.last_checkpoint_time
|
65
|
-
return elapsed_time >= self.interval_seconds
|
61
|
+
return elapsed_time >= interval.total_seconds()
|
62
|
+
|
63
|
+
def _should_checkpoint(self, trainer: Trainer):
|
64
|
+
if self.config.save_on_time_interval is None:
|
65
|
+
return False
|
66
|
+
return trainer.strategy.broadcast(self._local_should_checkpoint(), src=0)
|
66
67
|
|
67
68
|
def _format_duration(self, seconds: float) -> str:
|
68
69
|
"""Format duration in seconds to a human-readable string."""
|
@@ -98,7 +99,7 @@ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
|
|
98
99
|
*args,
|
99
100
|
**kwargs,
|
100
101
|
):
|
101
|
-
if not self._should_checkpoint():
|
102
|
+
if not self._should_checkpoint(trainer):
|
102
103
|
return
|
103
104
|
self.save_checkpoints(trainer)
|
104
105
|
|
@@ -110,5 +111,5 @@ class LastCheckpointCallback(CheckpointBase[LastCheckpointCallbackConfig]):
|
|
110
111
|
def save_checkpoints(self, trainer):
|
111
112
|
super().save_checkpoints(trainer)
|
112
113
|
|
113
|
-
if self.save_on_time_interval:
|
114
|
+
if self.config.save_on_time_interval is not None:
|
114
115
|
self.last_checkpoint_time = time.time()
|
nshtrainer/loggers/base.py
CHANGED
@@ -30,5 +30,14 @@ class LoggerConfigBase(C.Config, ABC):
|
|
30
30
|
def __bool__(self):
|
31
31
|
return self.enabled
|
32
32
|
|
33
|
+
def resolve_logger_dirname(self) -> str:
|
34
|
+
if not (name := getattr(self, "name", None)):
|
35
|
+
raise ValueError(
|
36
|
+
"Logger must have a name attribute to resolve the directory name.\n"
|
37
|
+
"Otherwise, you must override `resolve_logger_dirname`."
|
38
|
+
)
|
39
|
+
|
40
|
+
return name
|
41
|
+
|
33
42
|
|
34
43
|
logger_registry = C.Registry(LoggerConfigBase, discriminator="name")
|
{nshtrainer-1.0.0b37.dist-info → nshtrainer-1.0.0b39.dist-info}/RECORD
CHANGED
@@ -3,7 +3,7 @@ nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
|
|
3
3
|
nshtrainer/_callback.py,sha256=tXQCDzS6CvMTuTY5lQSH5qZs1pXUi-gt9bQdpXMVdEs,12715
|
4
4
|
nshtrainer/_checkpoint/metadata.py,sha256=PHy-54Cg-o3OtCffAqrVv6ZVMU7zhRo_-sZiSEEno1Y,5019
|
5
5
|
nshtrainer/_checkpoint/saver.py,sha256=LOP8jjKF0Dw9x9H-BKrLMWlEp1XTan2DUK0zQUCWw5U,1360
|
6
|
-
nshtrainer/_directory.py,sha256=
|
6
|
+
nshtrainer/_directory.py,sha256=TJR9ccyuzRlAVfVjGyeQ3E2AFAcz-XbBCxWfiXo2SlY,3191
|
7
7
|
nshtrainer/_experimental/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
|
8
8
|
nshtrainer/_hf_hub.py,sha256=4OsCbIITnZk_YLyoMrVyZ0SIN04FBxlC0ig2Et8UAdo,14287
|
9
9
|
nshtrainer/callbacks/__init__.py,sha256=4giOYT8A709UOLRtQEt16QbOAFUHCjJ_aLB7ITTwXJI,3577
|
@@ -12,7 +12,7 @@ nshtrainer/callbacks/base.py,sha256=Alaou1IHAIlMEM7g58d_02ozY2xWlshBN7fsw5Ee21s,
|
|
12
12
|
nshtrainer/callbacks/checkpoint/__init__.py,sha256=l8tkHc83_mLiU0-wT09SWdRzwpm2ulbkLzcuCmuTwzE,620
|
13
13
|
nshtrainer/callbacks/checkpoint/_base.py,sha256=ZVEUVl5kjCSSe69Q0rMUbKBNNUog0pxBwWkeyuxG2w0,6304
|
14
14
|
nshtrainer/callbacks/checkpoint/best_checkpoint.py,sha256=2CQuhPJ3Fi7lDw7z-J8kXXXuDU8-4HcU48oZxR49apk,2667
|
15
|
-
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=
|
15
|
+
nshtrainer/callbacks/checkpoint/last_checkpoint.py,sha256=vn-as3ex7kaTRcKsIurVtM6kUSHYNwHJeYG82j2dMcc,3554
|
16
16
|
nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py,sha256=nljzETqkHwA-4g8mxaeFK5HxA8My0dlIPzIUscSMWyk,3525
|
17
17
|
nshtrainer/callbacks/debug_flag.py,sha256=96fuP0C7C6dSs1GiMeUYzzs0X3Q4Pjt9JVWg3b75fU4,1748
|
18
18
|
nshtrainer/callbacks/directory_setup.py,sha256=wPas_Ren8ANejogmIdKhqqgj4ulxz9AS_8xVIAfRXa0,2565
|
@@ -101,7 +101,7 @@ nshtrainer/data/datamodule.py,sha256=lSOgH32nysJWa6Y7ba1QyOdUV0DVVdO98qokP8wigjk
|
|
101
101
|
nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
|
102
102
|
nshtrainer/loggers/__init__.py,sha256=Ddd3JJXVzew_ZpwHA9kGnGmvq4OwhItwghDL5PzNhDc,614
|
103
103
|
nshtrainer/loggers/actsave.py,sha256=wgNrpBB6wQM7qff8iLDb_sQnbiAcYHRmH56pcEJPB3o,1409
|
104
|
-
nshtrainer/loggers/base.py,sha256=
|
104
|
+
nshtrainer/loggers/base.py,sha256=ON92XbwTSgadQOSyw5PiRRFzyH6uJ-xLtE0nB3cbgPc,1205
|
105
105
|
nshtrainer/loggers/csv.py,sha256=xJ8mSmw4vJwinIfqhF6t2HWmh_1dXEYyLfGuXwL7WHo,1160
|
106
106
|
nshtrainer/loggers/tensorboard.py,sha256=E7iO_fDt9bfH02hBL430bXPLljOo5iGgq2QyPqmx2gQ,2324
|
107
107
|
nshtrainer/loggers/wandb.py,sha256=KZXAUWrrmdX_L8rqej77oUHaM0JxZRM8y9z6JP9PISw,6856
|
@@ -151,6 +151,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
|
151
151
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
152
152
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
153
153
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
154
|
-
nshtrainer-1.0.
|
155
|
-
nshtrainer-1.0.
|
156
|
-
nshtrainer-1.0.
|
154
|
+
nshtrainer-1.0.0b39.dist-info/METADATA,sha256=zzE6nHlj-clB3HJs5_-bBePCHSOrtTkZTi9z_NrSeRY,988
|
155
|
+
nshtrainer-1.0.0b39.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
156
|
+
nshtrainer-1.0.0b39.dist-info/RECORD,,
|
File without changes
|