nshtrainer 0.10.12__tar.gz → 0.10.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/PKG-INFO +1 -1
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/pyproject.toml +1 -1
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/_checkpoint/saver.py +13 -4
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +1 -1
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/model_checkpoint.py +1 -1
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/wandb_watch.py +24 -24
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/model/config.py +36 -6
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/README.md +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/callback.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/debug.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/distributed.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/logger.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/profiler.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/trainer/trainer.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/_checkpoint/saver.py
RENAMED

@@ -36,7 +36,8 @@ def _link_checkpoint(
         # fall back to copying the file
         shutil.copy(filepath, linkpath)

-    _link_checkpoint_metadata(filepath, linkpath)
+    if metadata:
+        _link_checkpoint_metadata(filepath, linkpath)
     if barrier:
         trainer.strategy.barrier()

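The hunk above makes the metadata link conditional on a new `metadata` flag. For readers outside the codebase, the link-then-copy pattern it sits in looks roughly like the following minimal standalone sketch; the helper name `link_with_fallback` is hypothetical, not nshtrainer's API:

```python
import os
import shutil
from pathlib import Path


def link_with_fallback(filepath: Path, linkpath: Path) -> None:
    """Symlink `linkpath` -> `filepath`, copying instead if symlinks fail."""
    linkpath.unlink(missing_ok=True)  # replace any previous link
    try:
        # A relative target keeps the link valid if the directory moves.
        linkpath.symlink_to(os.path.relpath(filepath, linkpath.parent))
    except OSError:
        # fall back to copying the file (e.g. filesystems without symlinks)
        shutil.copy(filepath, linkpath)
```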
@@ -44,9 +45,17 @@ def _link_checkpoint(
 def _remove_checkpoint(
     trainer: Trainer,
     filepath: str | Path | os.PathLike,
-
+    *,
+    metadata: bool,
+    barrier: bool,
 ):
     if not isinstance(filepath, Path):
         filepath = Path(filepath)
-    trainer.strategy.remove_checkpoint(filepath)
-    _remove_checkpoint_metadata(filepath)
+
+    if trainer.is_global_zero:
+        trainer.strategy.remove_checkpoint(filepath)
+        if metadata:
+            _remove_checkpoint_metadata(filepath)
+
+    if barrier:
+        trainer.strategy.barrier()
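`_remove_checkpoint` now takes `metadata` and `barrier` as required keyword-only arguments, so every call site must spell out both flags. A toy stand-in mirroring just the signature change (not the real helper, which also needs a live `Trainer`):

```python
from pathlib import Path


def remove_checkpoint_demo(filepath: str | Path, *, metadata: bool, barrier: bool) -> None:
    """Toy stand-in mirroring the new keyword-only signature."""
    filepath = Path(filepath)
    print(f"removing {filepath} (metadata={metadata}, barrier={barrier})")


remove_checkpoint_demo("epoch=3.ckpt", metadata=True, barrier=False)  # OK
# remove_checkpoint_demo("epoch=3.ckpt", True, False)  # TypeError: keyword-only
```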
{nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/latest_epoch_checkpoint.py
RENAMED

@@ -69,7 +69,7 @@ class LatestEpochCheckpoint(Checkpoint):

     def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
         for ckpt_path in ckpt_paths:
-            _remove_checkpoint(trainer, ckpt_path,
+            _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)

     def _remove_old_checkpoints(self, trainer: Trainer):
         if (latest_k := self.config.latest_k) == "all":
{nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/model_checkpoint.py
RENAMED

@@ -202,4 +202,4 @@ class ModelCheckpoint(_ModelCheckpoint):

     @override
     def _remove_checkpoint(self, trainer: Trainer, filepath: str):
-        return _remove_checkpoint(trainer, filepath,
+        return _remove_checkpoint(trainer, filepath, metadata=True, barrier=False)
{nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/callbacks/wandb_watch.py
RENAMED

@@ -12,13 +12,36 @@ from .base import CallbackConfigBase
 log = logging.getLogger(__name__)


+class WandbWatchConfig(CallbackConfigBase):
+    name: Literal["wandb_watch"] = "wandb_watch"
+
+    enabled: bool = True
+    """Enable watching the model for wandb."""
+
+    log: str | None = None
+    """Log type for wandb."""
+
+    log_graph: bool = True
+    """Whether to log the graph for wandb."""
+
+    log_freq: int = 100
+    """Log frequency for wandb."""
+
+    def __bool__(self):
+        return self.enabled
+
+    @override
+    def create_callbacks(self, root_config):
+        yield WandbWatchCallback(self)
+
+
 @runtime_checkable
 class _HasWandbLogModuleProtocol(Protocol):
     def wandb_log_module(self) -> nn.Module | None: ...


 class WandbWatchCallback(Callback):
-    def __init__(self, config:
+    def __init__(self, config: WandbWatchConfig):
         super().__init__()

         self.config = config

@@ -78,26 +101,3 @@ class WandbWatchCallback(Callback):
             log_graph=self.config.log_graph,
         )
         setattr(pl_module, "_model_watched", True)
-
-
-class WandbWatchConfig(CallbackConfigBase):
-    name: Literal["wandb_watch"] = "wandb_watch"
-
-    enabled: bool = True
-    """Enable watching the model for wandb."""
-
-    log: str | None = None
-    """Log type for wandb."""
-
-    log_graph: bool = True
-    """Whether to log the graph for wandb."""
-
-    log_freq: int = 100
-    """Log frequency for wandb."""
-
-    def __bool__(self):
-        return self.enabled
-
-    @override
-    def create_callbacks(self, root_config):
-        yield WandbWatchCallback(self)
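Moving `WandbWatchConfig` above `WandbWatchCallback` lets `__init__` annotate `config: WandbWatchConfig` without a forward reference. A hedged usage sketch, assuming the config behaves like a standard pydantic-style model (field values here are illustrative, and `create_callbacks` ignores its `root_config` argument per the hunk above):

```python
from nshtrainer.callbacks.wandb_watch import WandbWatchConfig

config = WandbWatchConfig(log="gradients", log_graph=True, log_freq=50)
if config:  # __bool__ delegates to `enabled`
    callbacks = list(config.create_callbacks(root_config=None))
```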
{nshtrainer-0.10.12 → nshtrainer-0.10.14}/src/nshtrainer/model/config.py
RENAMED

@@ -20,6 +20,7 @@ from typing import (

 import nshconfig as C
 import numpy as np
+import pkg_resources
 import torch
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT
@@ -213,7 +214,7 @@ class BaseLoggerConfig(C.Config, ABC):
     """Enable this logger."""

     priority: int = 0
-    """Priority of the logger. Higher
+    """Priority of the logger. Higher priority loggers are created first."""

     log_dir: DirectoryPath | None = None
     """Directory to save the logs to. If None, will use the default log directory for the trainer."""
@@ -266,7 +267,8 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
     """Enable WandB logging."""

     priority: int = 2
-    """Priority of the logger. Higher
+    """Priority of the logger. Higher priority loggers are created first,
+    and the highest priority logger is the "main" logger for PyTorch Lightning."""

     project: str | None = None
     """WandB project name to use for the logger. If None, will use the root config's project name."""
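The restored docstrings pin down the same rule in several logger configs: loggers are instantiated in descending `priority` order, and the highest-priority logger becomes the "main" one for PyTorch Lightning. A standalone illustration of that ordering rule (not nshtrainer's actual wiring; priorities taken from this diff):

```python
# Priorities from this diff: wandb=2, tensorboard=2, csv=0.
configs = [("csv", 0), ("wandb", 2), ("tensorboard", 2)]

# Higher priority first; the first entry would be the "main" logger.
# sorted() is stable, so equal priorities keep their original order.
ordered = sorted(configs, key=lambda pair: pair[1], reverse=True)
print(ordered)  # [('wandb', 2), ('tensorboard', 2), ('csv', 0)]
```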
@@ -286,8 +288,17 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
     offline: bool = False
     """Whether to run WandB in offline mode."""

-    def offline_(self, value: bool = True):
-        self.offline = value
+    use_wandb_core: bool = False
+    """Whether to use the new `wandb-core` backend for WandB.
+    `wandb-core` is a new backend for WandB that is faster and more efficient than the old backend.
+    """
+
+    def offline_(self, value: bool = True):
+        self.offline = value
+        return self
+
+    def core_(self, value: bool = True):
+        self.use_wandb_core = value
         return self

     @override
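Because `offline_` and `core_` both return `self`, they chain fluently. A hedged usage sketch; the `project` field exists per this diff, but other constructor details are assumed:

```python
from nshtrainer.model.config import WandbLoggerConfig

# Enable offline mode and the wandb-core backend in one expression.
logger_config = WandbLoggerConfig(project="my-project").offline_().core_()
assert logger_config.offline and logger_config.use_wandb_core
```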
@@ -295,6 +306,25 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
         if not self.enabled:
             return None

+        # If `wandb-core` is enabled, we should use the new backend.
+        if self.use_wandb_core:
+            try:
+                import wandb  # type: ignore
+
+                # The minimum version that supports the new backend is 0.17.5
+                if pkg_resources.parse_version(
+                    wandb.__version__
+                ) < pkg_resources.parse_version("0.17.5"):
+                    log.warning(
+                        "The version of WandB installed does not support the `wandb-core` backend. "
+                        "Unable to use the `wandb-core` backend for WandB."
+                    )
+                else:
+                    wandb.require("core")
+                    log.critical("Using the `wandb-core` backend for WandB.")
+            except ImportError:
+                pass
+
         from lightning.pytorch.loggers.wandb import WandbLogger

         save_dir = root_config.directory._resolve_log_directory_for_logger(
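The gate above compares versions via `pkg_resources`, which setuptools has since deprecated. An equivalent standalone check using the `packaging` library (my substitution for illustration, not what nshtrainer ships) would be:

```python
from packaging.version import Version

import wandb

# wandb-core needs wandb >= 0.17.5; require("core") opts in to the new backend.
if Version(wandb.__version__) >= Version("0.17.5"):
    wandb.require("core")
```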
@@ -329,7 +359,7 @@ class CSVLoggerConfig(BaseLoggerConfig):
     """Enable CSV logging."""

     priority: int = 0
-    """Priority of the logger. Higher
+    """Priority of the logger. Higher priority loggers are created first."""

     prefix: str = ""
     """A string to put at the beginning of metric keys."""
@@ -383,7 +413,7 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
     """Enable TensorBoard logging."""

     priority: int = 2
-    """Priority of the logger. Higher
+    """Priority of the logger. Higher priority loggers are created first."""

     log_graph: bool = False
     """