nshtrainer 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +1 -17
- nshtrainer/callbacks/__init__.py +3 -2
- nshtrainer/callbacks/base.py +3 -4
- nshtrainer/config.py +3 -288
- nshtrainer/lr_scheduler/__init__.py +3 -2
- nshtrainer/lr_scheduler/_base.py +3 -6
- nshtrainer/lr_scheduler/linear_warmup_cosine.py +5 -5
- nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +5 -4
- nshtrainer/model/__init__.py +0 -4
- nshtrainer/model/base.py +9 -71
- nshtrainer/model/config.py +39 -141
- nshtrainer/nn/nonlinearity.py +3 -4
- nshtrainer/optimizer.py +3 -7
- nshtrainer/runner.py +18 -8
- nshtrainer/trainer/signal_connector.py +22 -11
- nshtrainer/trainer/trainer.py +1 -1
- nshtrainer/typecheck.py +1 -0
- {nshtrainer-0.1.0.dist-info → nshtrainer-0.2.0.dist-info}/METADATA +13 -2
- {nshtrainer-0.1.0.dist-info → nshtrainer-0.2.0.dist-info}/RECORD +20 -27
- nshtrainer/_submit/print_environment_info.py +0 -31
- nshtrainer/_submit/session/_output.py +0 -12
- nshtrainer/_submit/session/_script.py +0 -109
- nshtrainer/_submit/session/lsf.py +0 -467
- nshtrainer/_submit/session/slurm.py +0 -573
- nshtrainer/_submit/session/unified.py +0 -350
- nshtrainer/util/singleton.py +0 -89
- {nshtrainer-0.1.0.dist-info → nshtrainer-0.2.0.dist-info}/WHEEL +0 -0
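The bulk of this release swaps the in-tree `TypedConfig`/`Field` helpers (nshtrainer/config.py shrinks from 288 lines to a stub) for the external `nshconfig` package: config classes now subclass `C.Config` and discriminated unions are tagged with `C.Field`, as the hunks below show. A minimal sketch of the new pattern, assuming `nshconfig (>=0.2.0,<0.3.0)` as pinned in METADATA; the `lr` field is illustrative and not taken from this diff:

from typing import Annotated, Literal, TypeAlias

import nshconfig as C


class AdamWConfig(C.Config):
    name: Literal["adamw"] = "adamw"
    lr: float = 1.0e-3  # illustrative field, not part of the diff


# Discriminated unions now use C.Field instead of the removed in-tree Field helper.
OptimizerConfig: TypeAlias = Annotated[AdamWConfig, C.Field(discriminator="name")]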
nshtrainer/model/config.py
CHANGED
@@ -1,7 +1,6 @@
 import copy
 import os
 import re
-import signal
 import socket
 import string
 import time
@@ -21,6 +20,7 @@ from typing import (
     runtime_checkable,
 )
 
+import nshconfig as C
 import numpy as np
 import torch
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
@@ -39,7 +39,6 @@ from typing_extensions import Self, TypedDict, TypeVar, override
 from ..callbacks import CallbackConfig
 from ..callbacks.base import CallbackConfigBase
 from ..callbacks.wandb_watch import WandbWatchConfig
-from ..config import Field, TypedConfig
 from ..util.slurm import parse_slurm_node_list
 
 log = getLogger(__name__)
@@ -49,7 +48,7 @@ class IdSeedWarning(Warning):
     pass
 
 
-class BaseProfilerConfig(TypedConfig, ABC):
+class BaseProfilerConfig(C.Config, ABC):
     dirpath: str | Path | None = None
     """
     Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
@@ -200,11 +199,11 @@ class PyTorchProfilerConfig(BaseProfilerConfig):
 
 ProfilerConfig: TypeAlias = Annotated[
     SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
-    Field(discriminator="kind"),
+    C.Field(discriminator="kind"),
 ]
 
 
-class EnvironmentClassInformationConfig(TypedConfig):
+class EnvironmentClassInformationConfig(C.Config):
     name: str
     module: str
     full_name: str
@@ -213,7 +212,7 @@ class EnvironmentClassInformationConfig(TypedConfig):
     source_file_path: Path | None = None
 
 
-class EnvironmentSLURMInformationConfig(TypedConfig):
+class EnvironmentSLURMInformationConfig(C.Config):
     hostname: str
     hostnames: list[str]
     job_id: str
@@ -271,7 +270,7 @@ class EnvironmentSLURMInformationConfig(TypedConfig):
         return None
 
 
-class EnvironmentLSFInformationConfig(TypedConfig):
+class EnvironmentLSFInformationConfig(C.Config):
     hostname: str
     hostnames: list[str]
     job_id: str
@@ -328,7 +327,7 @@ class EnvironmentLSFInformationConfig(TypedConfig):
         return None
 
 
-class EnvironmentLinuxEnvironmentConfig(TypedConfig):
+class EnvironmentLinuxEnvironmentConfig(C.Config):
     """
     Information about the Linux environment (e.g., current user, hostname, etc.)
     """
@@ -347,9 +346,25 @@ class EnvironmentLinuxEnvironmentConfig(TypedConfig):
     load_avg: tuple[float, float, float] | None = None
 
 
-class EnvironmentConfig(TypedConfig):
+class EnvironmentSnapshotConfig(C.Config):
+    snapshot_dir: Path | None = None
+    modules: list[str] | None = None
+
+    @classmethod
+    def from_current_environment(cls):
+        draft = cls.draft()
+        if snapshot_dir := os.environ.get("NSHRUNNER_SNAPSHOT_DIR"):
+            draft.snapshot_dir = Path(snapshot_dir)
+        if modules := os.environ.get("NSHRUNNER_SNAPSHOT_MODULES"):
+            draft.modules = modules.split(",")
+        return draft.finalize()
+
+
+class EnvironmentConfig(C.Config):
     cwd: Path | None = None
 
+    snapshot: EnvironmentSnapshotConfig | None = None
+
     python_executable: Path | None = None
     python_path: list[Path] | None = None
     python_version: str | None = None
@@ -372,7 +387,7 @@ class EnvironmentConfig(TypedConfig):
     seed_workers: bool | None = None
 
 
-class BaseLoggerConfig(TypedConfig, ABC):
+class BaseLoggerConfig(C.Config, ABC):
     enabled: bool = True
     """Enable this logger."""
 
@@ -426,7 +441,7 @@ def _wandb_available():
 class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
     kind: Literal["wandb"] = "wandb"
 
-    enabled: bool = Field(default_factory=lambda: _wandb_available())
+    enabled: bool = C.Field(default_factory=lambda: _wandb_available())
     """Enable WandB logging."""
 
     priority: int = 2
@@ -543,7 +558,7 @@ def _tensorboard_available():
 class TensorboardLoggerConfig(BaseLoggerConfig):
     kind: Literal["tensorboard"] = "tensorboard"
 
-    enabled: bool = Field(default_factory=lambda: _tensorboard_available())
+    enabled: bool = C.Field(default_factory=lambda: _tensorboard_available())
    """Enable TensorBoard logging."""
 
     priority: int = 2
@@ -589,7 +604,7 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
 
 LoggerConfig: TypeAlias = Annotated[
     WandbLoggerConfig | CSVLoggerConfig | TensorboardLoggerConfig,
-    Field(discriminator="kind"),
+    C.Field(discriminator="kind"),
 ]
 
 
@@ -684,7 +699,7 @@ class LoggingConfig(CallbackConfigBase):
             yield from logger.construct_callbacks(root_config)
 
 
-class GradientClippingConfig(TypedConfig):
+class GradientClippingConfig(C.Config):
     enabled: bool = True
     """Enable gradient clipping."""
     value: int | float
@@ -719,41 +734,6 @@ class OptimizationConfig(CallbackConfigBase):
         ).construct_callbacks(root_config)
 
 
-LogLevel: TypeAlias = Literal[
-    "CRITICAL", "FATAL", "ERROR", "WARN", "WARNING", "INFO", "DEBUG"
-]
-
-
-class PythonLogging(TypedConfig):
-    log_level: LogLevel | None = None
-    """Log level to use for the Python logger (or None to use the default)."""
-
-    rich: bool = False
-    """If enabled, will use the rich library to format the Python logger output."""
-    rich_tracebacks: bool = True
-    """If enabled, will use the rich library to format the Python logger tracebacks."""
-
-    lovely_tensors: bool = False
-    """If enabled, will use the lovely-tensors library to format PyTorch tensors. False by default as it causes issues when used with `torch.vmap`."""
-    lovely_numpy: bool = False
-    """If enabled, will use the lovely-numpy library to format numpy arrays. False by default as it causes some issues with other libaries."""
-
-    def pretty_(
-        self,
-        *,
-        log_level: LogLevel | None = "INFO",
-        torch: bool = True,
-        numpy: bool = True,
-        rich: bool = True,
-        rich_tracebacks: bool = True,
-    ):
-        self.log_level = log_level
-        self.lovely_tensors = torch
-        self.lovely_numpy = numpy
-        self.rich = rich
-        self.rich_tracebacks = rich_tracebacks
-
-
 TPlugin = TypeVar(
     "TPlugin",
     Precision,
@@ -813,7 +793,7 @@ StrategyLiteral: TypeAlias = Literal[
 ]
 
 
-class CheckpointLoadingConfig(TypedConfig):
+class CheckpointLoadingConfig(C.Config):
     path: Literal["best", "last", "hpc"] | str | Path | None = None
     """
     Checkpoint path to use when loading a checkpoint.
@@ -825,7 +805,7 @@ class CheckpointLoadingConfig(TypedConfig):
     """
 
 
-class DirectoryConfig(TypedConfig):
+class DirectoryConfig(C.Config):
     project_root: Path | None = None
     """
     Root directory for this project.
@@ -901,7 +881,7 @@ class DirectoryConfig(TypedConfig):
         return log_dir
 
 
-class ReproducibilityConfig(TypedConfig):
+class ReproducibilityConfig(C.Config):
     deterministic: bool | Literal["warn"] | None = None
     """
     If ``True``, sets whether PyTorch operations must use deterministic algorithms.
@@ -1116,7 +1096,7 @@ CheckpointCallbackConfig: TypeAlias = Annotated[
     ModelCheckpointCallbackConfig
     | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig,
-    Field(discriminator="kind"),
+    C.Field(discriminator="kind"),
 ]
 
 
@@ -1514,7 +1494,7 @@ class ActSaveConfig(CallbackConfigBase):
         return [ActSaveCallback()]
 
 
-class SanityCheckingConfig(TypedConfig):
+class SanityCheckingConfig(C.Config):
     reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
     """
     If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
@@ -1524,7 +1504,7 @@ class SanityCheckingConfig(TypedConfig):
     """
 
 
-class TrainerConfig(TypedConfig):
+class TrainerConfig(C.Config):
     checkpoint_loading: CheckpointLoadingConfig = CheckpointLoadingConfig()
     """Checkpoint loading configuration options."""
 
@@ -1739,87 +1719,7 @@ class TrainerConfig(TypedConfig):
     """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
 
 
-class SeedConfig(TypedConfig):
-    seed: int
-    """Seed for the random number generator."""
-
-    seed_workers: bool = False
-    """Whether to seed the workers of the dataloader."""
-
-
-Signal: TypeAlias = Literal[
-    "SIGHUP",
-    "SIGINT",
-    "SIGQUIT",
-    "SIGILL",
-    "SIGTRAP",
-    "SIGABRT",
-    "SIGBUS",
-    "SIGFPE",
-    "SIGKILL",
-    "SIGUSR1",
-    "SIGSEGV",
-    "SIGUSR2",
-    "SIGPIPE",
-    "SIGALRM",
-    "SIGTERM",
-    "SIGCHLD",
-    "SIGCONT",
-    "SIGSTOP",
-    "SIGTSTP",
-    "SIGTTIN",
-    "SIGTTOU",
-    "SIGURG",
-    "SIGXCPU",
-    "SIGXFSZ",
-    "SIGVTALRM",
-    "SIGPROF",
-    "SIGWINCH",
-    "SIGIO",
-    "SIGPWR",
-    "SIGSYS",
-    "SIGRTMIN",
-    "SIGRTMAX",
-]
-
-
-class SubmitConfig(TypedConfig):
-    auto_requeue_signals: list[Signal] = [
-        # "SIGUSR1",
-        # On SIGURG:
-        # Important note from https://amrex-astro.github.io/workflow/olcf-workflow.html:
-        # We can also ask the job manager to send a warning signal some amount of time before the allocation expires by passing -wa 'signal' and -wt '[hour:]minute' to bsub. We can then have bash create a dump_and_stop file when it receives the signal, which will tell Castro to output a checkpoint file and exit cleanly after it finishes the current timestep. An important detail that I couldn't find documented anywhere is that the job manager sends the signal to all the processes in the job, not just the submission script, and we have to use a signal that is ignored by default so Castro doesn't immediately crash upon receiving it. SIGCHLD, SIGURG, and SIGWINCH are the only signals that fit this requirement and of these, SIGURG is the least likely to be triggered by other events.
-        "SIGURG"
-    ]
-    """Signals that will trigger an automatic requeue of the job."""
-
-    def _resolved_auto_requeue_signals(self) -> list[signal.Signals]:
-        return [getattr(signal.Signals, sig) for sig in self.auto_requeue_signals]
-
-
-class RunnerConfig(TypedConfig):
-    python_logging: PythonLogging = PythonLogging()
-    """Python logging configuration options."""
-
-    seed: SeedConfig = SeedConfig(seed=0)
-    """Seed everything configuration options."""
-
-    submit: SubmitConfig = SubmitConfig()
-    """Submit (e.g., SLURM or LSF) configuration options."""
-
-    dump_run_information: bool = True
-    """
-    If enabled, will dump different bits of run information to the output directory before starting the run.
-    This includes:
-    - Run config
-    - Full set of environment variables
-    """
-
-    additional_env_vars: dict[str, str] = {}
-    """Additional environment variables to set when running the script."""
-
-
-class MetricConfig(TypedConfig):
+class MetricConfig(C.Config):
     name: str
     """The name of the primary metric."""
 
@@ -1851,8 +1751,8 @@ class MetricConfig(TypedConfig):
 PrimaryMetricConfig: TypeAlias = MetricConfig
 
 
-class BaseConfig(TypedConfig):
-    id: str = Field(default_factory=lambda: BaseConfig.generate_id())
+class BaseConfig(C.Config):
+    id: str = C.Field(default_factory=lambda: BaseConfig.generate_id())
     """ID of the run."""
     name: str | None = None
     """Run name."""
@@ -1867,15 +1767,13 @@ class BaseConfig(TypedConfig):
 
     debug: bool = False
     """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
-    environment: Annotated[EnvironmentConfig, Field(repr=False)] = EnvironmentConfig()
+    environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = EnvironmentConfig()
     """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""
 
     directory: DirectoryConfig = DirectoryConfig()
     """Directory configuration options."""
     trainer: TrainerConfig = TrainerConfig()
     """PyTorch Lightning trainer configuration options. Check Lightning's `Trainer` documentation for more information."""
-    runner: RunnerConfig = RunnerConfig()
-    """`ll.Runner` configuration options."""
 
     primary_metric: PrimaryMetricConfig | None = None
     """Primary metric configuration options. This is used in the following ways:
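The new EnvironmentSnapshotConfig above pulls snapshot metadata from nshrunner's NSHRUNNER_SNAPSHOT_DIR / NSHRUNNER_SNAPSHOT_MODULES environment variables through nshconfig's draft/finalize flow. A hedged usage sketch; the paths and module names are illustrative, and in practice nshrunner sets these variables before launching the run:

import os

# Illustrative values; normally populated by nshrunner, not by hand.
os.environ["NSHRUNNER_SNAPSHOT_DIR"] = "/tmp/example-run/snapshot"
os.environ["NSHRUNNER_SNAPSHOT_MODULES"] = "nshtrainer,my_project"

from nshtrainer.model.config import EnvironmentSnapshotConfig

snapshot = EnvironmentSnapshotConfig.from_current_environment()
print(snapshot.snapshot_dir)  # PosixPath('/tmp/example-run/snapshot')
print(snapshot.modules)       # ['nshtrainer', 'my_project']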
nshtrainer/nn/nonlinearity.py
CHANGED
@@ -1,14 +1,13 @@
 from abc import ABC, abstractmethod
 from typing import Annotated, Literal
 
+import nshconfig as C
 import torch
 import torch.nn as nn
 from typing_extensions import override
 
-from ..config import Field, TypedConfig
 
-
-class BaseNonlinearityConfig(TypedConfig, ABC):
+class BaseNonlinearityConfig(C.Config, ABC):
     @abstractmethod
     def create_module(self) -> nn.Module:
         pass
@@ -153,5 +152,5 @@ NonlinearityConfig = Annotated[
     | SiLUNonlinearityConfig
     | MishNonlinearityConfig
     | SwiGLUNonlinearityConfig,
-    Field(discriminator="name"),
+    C.Field(discriminator="name"),
 ]
nshtrainer/optimizer.py
CHANGED
@@ -2,14 +2,13 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from typing import Annotated, Any, Literal, TypeAlias
 
+import nshconfig as C
 import torch.nn as nn
 from torch.optim import Optimizer
 from typing_extensions import override
 
-from .config import Field, TypedConfig
 
-
-class OptimizerConfigBase(TypedConfig, ABC):
+class OptimizerConfigBase(C.Config, ABC):
     @abstractmethod
     def create_optimizer(
         self,
@@ -56,7 +55,4 @@ class AdamWConfig(OptimizerConfigBase):
         )
 
 
-OptimizerConfig: TypeAlias = Annotated[
-    AdamWConfig,
-    Field(discriminator="name"),
-]
+OptimizerConfig: TypeAlias = Annotated[AdamWConfig, C.Field(discriminator="name")]
nshtrainer/runner.py
CHANGED
@@ -1,21 +1,31 @@
-from dataclasses import dataclass
 from typing import Generic
 
+from nshrunner import RunInfo
 from nshrunner import Runner as _Runner
-from typing_extensions import
+from typing_extensions import TypeVar, TypeVarTuple, Unpack, override
 
 from .model.config import BaseConfig
 
 TConfig = TypeVar("TConfig", bound=BaseConfig, infer_variance=True)
-TArguments = TypeVarTuple("TArguments")
+TArguments = TypeVarTuple("TArguments", default=Unpack[tuple[()]])
 TReturn = TypeVar("TReturn", infer_variance=True)
 
 
-@dataclass(frozen=True)
 class Runner(
-    _Runner[
-    Generic[TConfig, Unpack[TArguments]
+    _Runner[TReturn, TConfig, Unpack[TArguments]],
+    Generic[TReturn, TConfig, Unpack[TArguments]],
 ):
     @override
-
-
+    @classmethod
+    def default_validate_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> None:
+        super().default_validate_fn(config, *args)
+
+    @override
+    @classmethod
+    def default_info_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> RunInfo:
+        run_info = super().default_info_fn(config, *args)
+        return {
+            **run_info,
+            "id": config.id,
+            "base_dir": config.directory.project_root,
+        }
nshtrainer/trainer/signal_connector.py
CHANGED
@@ -25,14 +25,21 @@ _SIGNUM = int | signal.Signals
 _HANDLER: TypeAlias = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
 
 
-
-
-
+def _resolve_requeue_signals():
+    signals: list[signal.Signals] = []
+
+    if timeout_signal_name := os.environ.get("NSHRUNNER_TIMEOUT_SIGNAL"):
+        signals.append(signal.Signals[timeout_signal_name])
 
-
-
+    if preempt_signal_name := os.environ.get("NSHRUNNER_PREEMPT_SIGNAL"):
+        signals.append(signal.Signals[preempt_signal_name])
 
-
+    return signals
+
+
+class _SignalConnector(_LightningSignalConnector):
+    def _auto_requeue_signals(self) -> list[signal.Signals]:
+        signals = _resolve_requeue_signals()
         signals_set = set(signals)
         valid_signals: set[signal.Signals] = signal.valid_signals()
         assert signals_set.issubset(
@@ -42,25 +49,29 @@ class _SignalConnector(_LightningSignalConnector):
 
     def _compose_and_register(
         self,
-        signum:
+        signum: signal.Signals,
         handlers: list[_HANDLER],
         replace_existing: bool = False,
     ):
         if self._is_on_windows():
-            log.info(
+            log.info(
+                f"Signal {signum.name} has no handlers or is not supported on Windows."
+            )
             return
 
         if self._has_already_handler(signum):
             if not replace_existing:
                 log.info(
-                    f"Signal {signum} already has a handler. Adding ours to the existing one."
+                    f"Signal {signum.name} already has a handler. Adding ours to the existing one."
                 )
                 handlers.append(signal.getsignal(signum))
             else:
-                log.info(
+                log.info(
+                    f"Replacing existing handler for signal {signum.name} with ours."
+                )
 
         self._register_signal(signum, _HandlersCompose(handlers))
-        log.info(f"Registered {len(handlers)} handlers for signal {signum}.")
+        log.info(f"Registered {len(handlers)} handlers for signal {signum.name}.")
 
     @override
     def register_signal_handlers(self) -> None:
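With SubmitConfig/RunnerConfig gone from config.py, the signal connector now learns its requeue signals from nshrunner's environment instead of the run config. A small sketch of that lookup, mirroring the _resolve_requeue_signals() hunk above; the exported values are illustrative:

import os
import signal

# Illustrative values; nshrunner exports these when it submits the job.
os.environ["NSHRUNNER_TIMEOUT_SIGNAL"] = "SIGURG"
os.environ["NSHRUNNER_PREEMPT_SIGNAL"] = "SIGTERM"

signals: list[signal.Signals] = []
for var in ("NSHRUNNER_TIMEOUT_SIGNAL", "NSHRUNNER_PREEMPT_SIGNAL"):
    if name := os.environ.get(var):
        signals.append(signal.Signals[name])

print(signals)  # e.g. [<Signals.SIGURG: ...>, <Signals.SIGTERM: ...>]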
nshtrainer/trainer/trainer.py
CHANGED
@@ -31,7 +31,7 @@ log = logging.getLogger(__name__)
 
 def _is_bf16_supported_no_emulation():
     r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
-    version =
+    version = getattr(torch, "version")
 
     # Check for ROCm, if true return true, no ROCM_VERSION check required,
     # since it is supported on AMD GPU archs.
nshtrainer/typecheck.py
CHANGED
@@ -82,6 +82,7 @@ def typecheck_this_module(additional_modules: Sequence[str] = ()):
     frame = get_frame(1)
     assert frame is not None, "frame is None"
     calling_module_name = get_frame_package_name(frame)
+    assert calling_module_name is not None, "calling_module_name is None"
 
     # Typecheck the calling module + any additional modules.
     typecheck_modules((calling_module_name, *additional_modules))
{nshtrainer-0.1.0.dist-info → nshtrainer-0.2.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.1.0
+Version: 0.2.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -9,10 +9,21 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Requires-Dist: beartype (>=0.18.5,<0.19.0)
+Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
+Requires-Dist: lightning
+Requires-Dist: lovely-numpy (>=0.2.13,<0.3.0)
+Requires-Dist: lovely-tensors (>=0.1.16,<0.2.0)
 Requires-Dist: nshconfig (>=0.2.0,<0.3.0)
-Requires-Dist: nshrunner (>=0.
+Requires-Dist: nshrunner (>=0.5.4,<0.6.0)
+Requires-Dist: numpy
+Requires-Dist: pysnooper
+Requires-Dist: pytorch-lightning
+Requires-Dist: rich
 Requires-Dist: torch
+Requires-Dist: torchmetrics
 Requires-Dist: typing-extensions
+Requires-Dist: wrapt
 Description-Content-Type: text/markdown
 
 
{nshtrainer-0.1.0.dist-info → nshtrainer-0.2.0.dist-info}/RECORD
CHANGED
@@ -1,22 +1,16 @@
-nshtrainer/__init__.py,sha256=
+nshtrainer/__init__.py,sha256=_r7kBmgGSLVfActlqQeupNolrmBu45xUuSS8odt3HL8,2208
 nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
 nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
 nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
 nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
 nshtrainer/_snoop.py,sha256=Rofv1Rd92E0LY40G3A-o9Hu0ZI73RR59wJD5l4Q3PDM,7022
-nshtrainer/_submit/print_environment_info.py,sha256=enbJGl_iHIlhKN8avzKnoZSb0zUQ_fUdnsQ8a_9tbYk,963
-nshtrainer/_submit/session/_output.py,sha256=CNGH5W6_XxAC5-TRvMAMxOHd3fjGpJhK-7RGTDyvMu4,245
-nshtrainer/_submit/session/_script.py,sha256=0AeBgBduDsoIEBrY9kebARiBUEGc50JAD9oE_IDiLnA,3775
-nshtrainer/_submit/session/lsf.py,sha256=p19EP6OhROZxqfRhzeTD7GDmfYaREIKMXMOI8G933FE,14307
-nshtrainer/_submit/session/slurm.py,sha256=JpAjQvck4LjGN8o8fOvIeMuFqrg1cioANoVsX5hU-3g,17594
-nshtrainer/_submit/session/unified.py,sha256=gfh-AtnMyFHzcQOUlhlAR__vaWDk1r9XCivz_t_lHKk,11695
 nshtrainer/actsave/__init__.py,sha256=G1T-fELuGWkVqdhdyoePtj2dTOUtcIOW4VgsXv9JNTA,338
 nshtrainer/actsave/_callback.py,sha256=QoTa60F70f1RxB41VKixN9l5_htfFQxXDPHHSNFreuk,2770
 nshtrainer/actsave/_loader.py,sha256=fAhD32DrJa4onkYfcwc21YIeGEYzOSXCK_HVo9SZLgQ,4604
 nshtrainer/actsave/_saver.py,sha256=0EHmQDhqVxQWRWWSyt03eP1K9ETiACMQYmsZkDMt6HY,9451
-nshtrainer/callbacks/__init__.py,sha256=
+nshtrainer/callbacks/__init__.py,sha256=I6W33ityL9Ko8jjqHh3WH_8miV59SAe9LxInhoqX5XE,1665
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
-nshtrainer/callbacks/base.py,sha256=
+nshtrainer/callbacks/base.py,sha256=LrcRUV02bZEKXRIRvhHT9qsvw_kwoWiAdQkVMyKc5NU,3542
 nshtrainer/callbacks/early_stopping.py,sha256=jriSU761wf_qTJ9Bos0D3h5aDvZHYpRqK62Ne8aWp5I,3768
 nshtrainer/callbacks/ema.py,sha256=zKCtvzZFo0ORlwNZHjaMk-sJoxrlTtFWOzR-yGy95W0,12134
 nshtrainer/callbacks/finite_checks.py,sha256=kX3TIJsxyqx0GuLJfYsqVgKU27zwjG9Z8324lyCFtwM,2087
@@ -30,17 +24,17 @@ nshtrainer/callbacks/print_table.py,sha256=FcA-CBWwMf9c1NNRinvYpZC400RNQxuP28bJf
 nshtrainer/callbacks/throughput_monitor.py,sha256=YQLdpX3LGybIiD814yT9yCCVSEXRWf8WwsvVaN5aDBE,1848
 nshtrainer/callbacks/timer.py,sha256=sDXPPcdDKu5xnuK_bjr8plIq9MBuluNJ42Mt9LvPZzc,4610
 nshtrainer/callbacks/wandb_watch.py,sha256=pUpMsNxd03ex1rzOmFw2HzGOXjnQGaH84m8cc2dXo4g,2937
-nshtrainer/config.py,sha256=
+nshtrainer/config.py,sha256=IXOAl_JWFNX9kPTo_iw4Nc3qXqkKrbA6-ZrvTAjqu6A,104
 nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
 nshtrainer/data/balanced_batch_sampler.py,sha256=bcJBcQjh1hB1yKF_xSlT9AtEWv0BJjYc1CuH2BF-ea8,4392
 nshtrainer/data/transform.py,sha256=JeGxvytQly8hougrsdMmKG8gJ6qvFPDglJCO4Tp6STk,1795
-nshtrainer/lr_scheduler/__init__.py,sha256=
-nshtrainer/lr_scheduler/_base.py,sha256=
-nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=
-nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=
-nshtrainer/model/__init__.py,sha256=
-nshtrainer/model/base.py,sha256=
-nshtrainer/model/config.py,sha256=
+nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
+nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
+nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9btIxMRWigUHUTlUYCSw,5221
+nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=2ZdlV0RUMwg2DClzqYHr8_EKT1jZBUlSD39e-XlCsC4,2764
+nshtrainer/model/__init__.py,sha256=y32Hla-5whpzLL2BtCJpBakSp8o-1nQbpO0j_-xq_Po,1864
+nshtrainer/model/base.py,sha256=EMkOtp4YWGPHM0HPSTLbx75T9vlYmXO4XyD725xU70w,21453
+nshtrainer/model/config.py,sha256=6lATW6-Z1SIDgQ1IWrGBVQKTr8DhL5b_rFbJHQz0d5o,66796
 nshtrainer/model/modules/callback.py,sha256=JF59U9-CjJsAIspEhTJbVaGN0wGctZG7UquE3IS7R8A,6408
 nshtrainer/model/modules/debug.py,sha256=DTVty8cKnzj1GCULRyGx_sWTTsq9NLi30dzqjRTnuCU,1127
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -52,21 +46,20 @@ nshtrainer/nn/__init__.py,sha256=57LPaP3G-BBGD2eGxbBUABNgYl3s_oASwrtOSS4bzTs,133
 nshtrainer/nn/mlp.py,sha256=i-dHk0tomO_XlU6cKN4CC4HxTaYb-ukBCAgY1ySXl4I,3963
 nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
 nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
-nshtrainer/nn/nonlinearity.py,sha256=
-nshtrainer/optimizer.py,sha256=
-nshtrainer/runner.py,sha256=
+nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
+nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
+nshtrainer/runner.py,sha256=vyHr0EZ0PBOWZh09BtOOxio-FRQZFbVoL4cdBlI97vY,991
 nshtrainer/scripts/check_env.py,sha256=IMl6dSqsLYppI0XuCsVq8lK4bYqXwY9KHJkzsShz4Kg,806
 nshtrainer/scripts/find_packages.py,sha256=FbdlfmAefttFSMfaT0A46a-oHLP_ioaQKihwBfBeWeA,1467
 nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
-nshtrainer/trainer/signal_connector.py,sha256=
-nshtrainer/trainer/trainer.py,sha256=
-nshtrainer/typecheck.py,sha256=
+nshtrainer/trainer/signal_connector.py,sha256=QAoPM_C5JJOVQebcrJOimUUD3GHyoeZUqCEAvzZlT4U,8710
+nshtrainer/trainer/trainer.py,sha256=eYEYfY9v70MuorHcSf8nqM7f2CkmUHhpPcjCk4FJD7k,14034
+nshtrainer/typecheck.py,sha256=RGYHxDBcs97E6ayl6Olc43JBZXQolCtMxcLBniVCVBg,4688
 nshtrainer/util/environment.py,sha256=_SEtiQ_s5bL5pllUlf96AOUv15kNvCPvocVC13S7mIk,4166
 nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
-nshtrainer/util/singleton.py,sha256=nLhpuMZxl0zdNsnvS97o4ASUnKzCWYEKLzR_j9oP_xs,2208
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.
-nshtrainer-0.
-nshtrainer-0.
+nshtrainer-0.2.0.dist-info/METADATA,sha256=cwb3IbKGyJ9HbNSvsORYhCiI61nrDMb1dVm5nE1q_XA,882
+nshtrainer-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.2.0.dist-info/RECORD,,
nshtrainer/_submit/print_environment_info.py
DELETED
@@ -1,31 +0,0 @@
-import logging
-import os
-import sys
-
-
-def print_environment_info(log: logging.Logger | None = None):
-    if log is None:
-        logging.basicConfig(level=logging.INFO)
-        log = logging.getLogger(__name__)
-
-    log_message_lines: list[str] = []
-    log_message_lines.append("Python executable: " + sys.executable)
-    log_message_lines.append("Python version: " + sys.version)
-    log_message_lines.append("Python prefix: " + sys.prefix)
-    log_message_lines.append("Python path:")
-    for path in sys.path:
-        log_message_lines.append(f" {path}")
-
-    log_message_lines.append("Environment variables:")
-    for key, value in os.environ.items():
-        log_message_lines.append(f" {key}={value}")
-
-    log_message_lines.append("Command line arguments:")
-    for i, arg in enumerate(sys.argv):
-        log_message_lines.append(f" {i}: {arg}")
-
-    log.critical("\n".join(log_message_lines))
-
-
-if __name__ == "__main__":
-    print_environment_info()