nshtrainer 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
nshtrainer/model/config.py CHANGED
@@ -1,7 +1,6 @@
 import copy
 import os
 import re
-import signal
 import socket
 import string
 import time
@@ -21,6 +20,7 @@ from typing import (
     runtime_checkable,
 )
 
+import nshconfig as C
 import numpy as np
 import torch
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
@@ -39,7 +39,6 @@ from typing_extensions import Self, TypedDict, TypeVar, override
 from ..callbacks import CallbackConfig
 from ..callbacks.base import CallbackConfigBase
 from ..callbacks.wandb_watch import WandbWatchConfig
-from ..config import Field, TypedConfig
 from ..util.slurm import parse_slurm_node_list
 
 log = getLogger(__name__)
@@ -49,7 +48,7 @@ class IdSeedWarning(Warning):
     pass
 
 
-class BaseProfilerConfig(TypedConfig, ABC):
+class BaseProfilerConfig(C.Config, ABC):
     dirpath: str | Path | None = None
     """
    Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
@@ -200,11 +199,11 @@ class PyTorchProfilerConfig(BaseProfilerConfig):
 
 ProfilerConfig: TypeAlias = Annotated[
     SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
-    Field(discriminator="kind"),
+    C.Field(discriminator="kind"),
 ]
 
 
-class EnvironmentClassInformationConfig(TypedConfig):
+class EnvironmentClassInformationConfig(C.Config):
     name: str
     module: str
     full_name: str
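
Note: `nshconfig` configs are pydantic-based, so `Annotated[..., C.Field(discriminator="kind")]` behaves as a tagged union: the `kind` value in the input selects which concrete config class gets instantiated. A minimal sketch (the `Demo` model and payload are illustrative, not part of nshtrainer; this assumes `C.Config` exposes pydantic's `model_validate`):

```python
from typing import Annotated, Literal, TypeAlias

import nshconfig as C

class SimpleProfilerConfig(C.Config):
    kind: Literal["simple"] = "simple"

class PyTorchProfilerConfig(C.Config):
    kind: Literal["pytorch"] = "pytorch"

# Same shape as the ProfilerConfig alias above.
ProfilerConfig: TypeAlias = Annotated[
    SimpleProfilerConfig | PyTorchProfilerConfig,
    C.Field(discriminator="kind"),
]

class Demo(C.Config):
    profiler: ProfilerConfig

# The "kind" value picks the concrete class during validation.
cfg = Demo.model_validate({"profiler": {"kind": "pytorch"}})
assert isinstance(cfg.profiler, PyTorchProfilerConfig)
```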
@@ -213,7 +212,7 @@ class EnvironmentClassInformationConfig(TypedConfig):
     source_file_path: Path | None = None
 
 
-class EnvironmentSLURMInformationConfig(TypedConfig):
+class EnvironmentSLURMInformationConfig(C.Config):
     hostname: str
     hostnames: list[str]
     job_id: str
@@ -271,7 +270,7 @@ class EnvironmentSLURMInformationConfig(TypedConfig):
         return None
 
 
-class EnvironmentLSFInformationConfig(TypedConfig):
+class EnvironmentLSFInformationConfig(C.Config):
     hostname: str
     hostnames: list[str]
     job_id: str
@@ -328,7 +327,7 @@ class EnvironmentLSFInformationConfig(TypedConfig):
         return None
 
 
-class EnvironmentLinuxEnvironmentConfig(TypedConfig):
+class EnvironmentLinuxEnvironmentConfig(C.Config):
     """
     Information about the Linux environment (e.g., current user, hostname, etc.)
     """
@@ -347,9 +346,25 @@ class EnvironmentLinuxEnvironmentConfig(TypedConfig):
     load_avg: tuple[float, float, float] | None = None
 
 
-class EnvironmentConfig(TypedConfig):
+class EnvironmentSnapshotConfig(C.Config):
+    snapshot_dir: Path | None = None
+    modules: list[str] | None = None
+
+    @classmethod
+    def from_current_environment(cls):
+        draft = cls.draft()
+        if snapshot_dir := os.environ.get("NSHRUNNER_SNAPSHOT_DIR"):
+            draft.snapshot_dir = Path(snapshot_dir)
+        if modules := os.environ.get("NSHRUNNER_SNAPSHOT_MODULES"):
+            draft.modules = modules.split(",")
+        return draft.finalize()
+
+
+class EnvironmentConfig(C.Config):
     cwd: Path | None = None
 
+    snapshot: EnvironmentSnapshotConfig | None = None
+
     python_executable: Path | None = None
     python_path: list[Path] | None = None
     python_version: str | None = None
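
Note: `from_current_environment` uses nshconfig's draft/finalize flow: `cls.draft()` returns a mutable draft and `finalize()` validates it into the immutable config. It reads the two `NSHRUNNER_SNAPSHOT_*` variables that nshrunner sets for snapshot runs. A usage sketch with illustrative values (the import path assumes this file is `nshtrainer/model/config.py`):

```python
import os
from pathlib import Path

from nshtrainer.model.config import EnvironmentSnapshotConfig

# Illustrative values; in a real run, nshrunner exports these variables.
os.environ["NSHRUNNER_SNAPSHOT_DIR"] = "/tmp/run-snapshot"
os.environ["NSHRUNNER_SNAPSHOT_MODULES"] = "nshtrainer,myproject"

snapshot = EnvironmentSnapshotConfig.from_current_environment()
assert snapshot.snapshot_dir == Path("/tmp/run-snapshot")
assert snapshot.modules == ["nshtrainer", "myproject"]
```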
@@ -372,7 +387,7 @@ class EnvironmentConfig(TypedConfig):
     seed_workers: bool | None = None
 
 
-class BaseLoggerConfig(TypedConfig, ABC):
+class BaseLoggerConfig(C.Config, ABC):
     enabled: bool = True
     """Enable this logger."""
 
@@ -426,7 +441,7 @@ def _wandb_available():
 class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
     kind: Literal["wandb"] = "wandb"
 
-    enabled: bool = Field(default_factory=lambda: _wandb_available())
+    enabled: bool = C.Field(default_factory=lambda: _wandb_available())
     """Enable WandB logging."""
 
     priority: int = 2
@@ -543,7 +558,7 @@ def _tensorboard_available():
 class TensorboardLoggerConfig(BaseLoggerConfig):
     kind: Literal["tensorboard"] = "tensorboard"
 
-    enabled: bool = Field(default_factory=lambda: _tensorboard_available())
+    enabled: bool = C.Field(default_factory=lambda: _tensorboard_available())
     """Enable TensorBoard logging."""
 
     priority: int = 2
@@ -589,7 +604,7 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
 
 LoggerConfig: TypeAlias = Annotated[
     WandbLoggerConfig | CSVLoggerConfig | TensorboardLoggerConfig,
-    Field(discriminator="kind"),
+    C.Field(discriminator="kind"),
 ]
 
 
@@ -684,7 +699,7 @@ class LoggingConfig(CallbackConfigBase):
         yield from logger.construct_callbacks(root_config)
 
 
-class GradientClippingConfig(TypedConfig):
+class GradientClippingConfig(C.Config):
     enabled: bool = True
     """Enable gradient clipping."""
     value: int | float
@@ -719,41 +734,6 @@ class OptimizationConfig(CallbackConfigBase):
     ).construct_callbacks(root_config)
 
 
-LogLevel: TypeAlias = Literal[
-    "CRITICAL", "FATAL", "ERROR", "WARN", "WARNING", "INFO", "DEBUG"
-]
-
-
-class PythonLogging(TypedConfig):
-    log_level: LogLevel | None = None
-    """Log level to use for the Python logger (or None to use the default)."""
-
-    rich: bool = False
-    """If enabled, will use the rich library to format the Python logger output."""
-    rich_tracebacks: bool = True
-    """If enabled, will use the rich library to format the Python logger tracebacks."""
-
-    lovely_tensors: bool = False
-    """If enabled, will use the lovely-tensors library to format PyTorch tensors. False by default as it causes issues when used with `torch.vmap`."""
-    lovely_numpy: bool = False
-    """If enabled, will use the lovely-numpy library to format numpy arrays. False by default as it causes some issues with other libraries."""
-
-    def pretty_(
-        self,
-        *,
-        log_level: LogLevel | None = "INFO",
-        torch: bool = True,
-        numpy: bool = True,
-        rich: bool = True,
-        rich_tracebacks: bool = True,
-    ):
-        self.log_level = log_level
-        self.lovely_tensors = torch
-        self.lovely_numpy = numpy
-        self.rich = rich
-        self.rich_tracebacks = rich_tracebacks
-
-
 TPlugin = TypeVar(
     "TPlugin",
     Precision,
@@ -813,7 +793,7 @@ StrategyLiteral: TypeAlias = Literal[
 ]
 
 
-class CheckpointLoadingConfig(TypedConfig):
+class CheckpointLoadingConfig(C.Config):
     path: Literal["best", "last", "hpc"] | str | Path | None = None
     """
     Checkpoint path to use when loading a checkpoint.
@@ -825,7 +805,7 @@ class CheckpointLoadingConfig(TypedConfig):
     """
 
 
-class DirectoryConfig(TypedConfig):
+class DirectoryConfig(C.Config):
     project_root: Path | None = None
     """
     Root directory for this project.
@@ -901,7 +881,7 @@ class DirectoryConfig(TypedConfig):
         return log_dir
 
 
-class ReproducibilityConfig(TypedConfig):
+class ReproducibilityConfig(C.Config):
     deterministic: bool | Literal["warn"] | None = None
     """
     If ``True``, sets whether PyTorch operations must use deterministic algorithms.
@@ -1116,7 +1096,7 @@ CheckpointCallbackConfig: TypeAlias = Annotated[
     ModelCheckpointCallbackConfig
     | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig,
-    Field(discriminator="kind"),
+    C.Field(discriminator="kind"),
 ]
 
 
@@ -1514,7 +1494,7 @@ class ActSaveConfig(CallbackConfigBase):
         return [ActSaveCallback()]
 
 
-class SanityCheckingConfig(TypedConfig):
+class SanityCheckingConfig(C.Config):
     reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
     """
     If enabled, will do some sanity checks if the `ReduceLROnPlateau` scheduler is used:
@@ -1524,7 +1504,7 @@ class SanityCheckingConfig(TypedConfig):
     """
 
 
-class TrainerConfig(TypedConfig):
+class TrainerConfig(C.Config):
     checkpoint_loading: CheckpointLoadingConfig = CheckpointLoadingConfig()
     """Checkpoint loading configuration options."""
 
@@ -1739,87 +1719,7 @@ class TrainerConfig(TypedConfig):
     """If enabled, will set the torch float32 matmul precision to the specified value. Useful for faster training on Ampere+ GPUs."""
 
 
-class SeedConfig(TypedConfig):
-    seed: int
-    """Seed for the random number generator."""
-
-    seed_workers: bool = False
-    """Whether to seed the workers of the dataloader."""
-
-
-Signal: TypeAlias = Literal[
-    "SIGHUP",
-    "SIGINT",
-    "SIGQUIT",
-    "SIGILL",
-    "SIGTRAP",
-    "SIGABRT",
-    "SIGBUS",
-    "SIGFPE",
-    "SIGKILL",
-    "SIGUSR1",
-    "SIGSEGV",
-    "SIGUSR2",
-    "SIGPIPE",
-    "SIGALRM",
-    "SIGTERM",
-    "SIGCHLD",
-    "SIGCONT",
-    "SIGSTOP",
-    "SIGTSTP",
-    "SIGTTIN",
-    "SIGTTOU",
-    "SIGURG",
-    "SIGXCPU",
-    "SIGXFSZ",
-    "SIGVTALRM",
-    "SIGPROF",
-    "SIGWINCH",
-    "SIGIO",
-    "SIGPWR",
-    "SIGSYS",
-    "SIGRTMIN",
-    "SIGRTMAX",
-]
-
-
-class SubmitConfig(TypedConfig):
-    auto_requeue_signals: list[Signal] = [
-        # "SIGUSR1",
-        # On SIGURG:
-        # Important note from https://amrex-astro.github.io/workflow/olcf-workflow.html:
-        # We can also ask the job manager to send a warning signal some amount of time before the allocation expires by passing -wa 'signal' and -wt '[hour:]minute' to bsub. We can then have bash create a dump_and_stop file when it receives the signal, which will tell Castro to output a checkpoint file and exit cleanly after it finishes the current timestep. An important detail that I couldn't find documented anywhere is that the job manager sends the signal to all the processes in the job, not just the submission script, and we have to use a signal that is ignored by default so Castro doesn't immediately crash upon receiving it. SIGCHLD, SIGURG, and SIGWINCH are the only signals that fit this requirement and of these, SIGURG is the least likely to be triggered by other events.
-        "SIGURG"
-    ]
-    """Signals that will trigger an automatic requeue of the job."""
-
-    def _resolved_auto_requeue_signals(self) -> list[signal.Signals]:
-        return [getattr(signal.Signals, sig) for sig in self.auto_requeue_signals]
-
-
-class RunnerConfig(TypedConfig):
-    python_logging: PythonLogging = PythonLogging()
-    """Python logging configuration options."""
-
-    seed: SeedConfig = SeedConfig(seed=0)
-    """Seed everything configuration options."""
-
-    submit: SubmitConfig = SubmitConfig()
-    """Submit (e.g., SLURM or LSF) configuration options."""
-
-    dump_run_information: bool = True
-    """
-    If enabled, will dump different bits of run information to the output directory before starting the run.
-    This includes:
-    - Run config
-    - Full set of environment variables
-    """
-
-    additional_env_vars: dict[str, str] = {}
-    """Additional environment variables to set when running the script."""
-
-
-class MetricConfig(TypedConfig):
+class MetricConfig(C.Config):
     name: str
     """The name of the primary metric."""
 
@@ -1851,8 +1751,8 @@ class MetricConfig(TypedConfig):
 PrimaryMetricConfig: TypeAlias = MetricConfig
 
 
-class BaseConfig(TypedConfig):
-    id: str = Field(default_factory=lambda: BaseConfig.generate_id())
+class BaseConfig(C.Config):
+    id: str = C.Field(default_factory=lambda: BaseConfig.generate_id())
     """ID of the run."""
     name: str | None = None
     """Run name."""
@@ -1867,15 +1767,13 @@ class BaseConfig(TypedConfig):
 
     debug: bool = False
     """Whether to run in debug mode. This will enable debug logging and enable debug code paths."""
-    environment: Annotated[EnvironmentConfig, Field(repr=False)] = EnvironmentConfig()
+    environment: Annotated[EnvironmentConfig, C.Field(repr=False)] = EnvironmentConfig()
     """A snapshot of the current environment information (e.g. python version, slurm info, etc.). This is automatically populated by the run script."""
 
     directory: DirectoryConfig = DirectoryConfig()
     """Directory configuration options."""
     trainer: TrainerConfig = TrainerConfig()
     """PyTorch Lightning trainer configuration options. Check Lightning's `Trainer` documentation for more information."""
-    runner: RunnerConfig = RunnerConfig()
-    """`ll.Runner` configuration options."""
 
     primary_metric: PrimaryMetricConfig | None = None
     """Primary metric configuration options. This is used in the following ways:
nshtrainer/nn/nonlinearity.py CHANGED
@@ -1,14 +1,13 @@
 from abc import ABC, abstractmethod
 from typing import Annotated, Literal
 
+import nshconfig as C
 import torch
 import torch.nn as nn
 from typing_extensions import override
 
-from ..config import Field, TypedConfig
 
-
-class BaseNonlinearityConfig(TypedConfig, ABC):
+class BaseNonlinearityConfig(C.Config, ABC):
     @abstractmethod
     def create_module(self) -> nn.Module:
         pass
@@ -153,5 +152,5 @@ NonlinearityConfig = Annotated[
     | SiLUNonlinearityConfig
     | MishNonlinearityConfig
     | SwiGLUNonlinearityConfig,
-    Field(discriminator="name"),
+    C.Field(discriminator="name"),
 ]
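
Note: concrete nonlinearity configs carry a `name` literal (the union's discriminator) and implement `create_module` to produce the actual `nn.Module`. A sketch of a hypothetical variant under the new base class (the real configs in this module, such as `SiLUNonlinearityConfig`, may differ in detail):

```python
from typing import Literal

import torch.nn as nn
from typing_extensions import override

from nshtrainer.nn.nonlinearity import BaseNonlinearityConfig

class MySiLUConfig(BaseNonlinearityConfig):
    name: Literal["my_silu"] = "my_silu"  # discriminator value

    @override
    def create_module(self) -> nn.Module:
        return nn.SiLU()

module = MySiLUConfig().create_module()  # -> nn.SiLU()
```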
nshtrainer/optimizer.py CHANGED
@@ -2,14 +2,13 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from typing import Annotated, Any, Literal, TypeAlias
 
+import nshconfig as C
 import torch.nn as nn
 from torch.optim import Optimizer
 from typing_extensions import override
 
-from .config import Field, TypedConfig
 
-
-class OptimizerConfigBase(TypedConfig, ABC):
+class OptimizerConfigBase(C.Config, ABC):
     @abstractmethod
     def create_optimizer(
         self,
@@ -56,7 +55,4 @@ class AdamWConfig(OptimizerConfigBase):
         )
 
 
-OptimizerConfig: TypeAlias = Annotated[
-    AdamWConfig,
-    Field(discriminator="name"),
-]
+OptimizerConfig: TypeAlias = Annotated[AdamWConfig, C.Field(discriminator="name")]
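
Note: collapsing the `OptimizerConfig` alias onto one line is purely cosmetic; a single-member discriminated union validates the same either way. For reference, a sketch of implementing the `OptimizerConfigBase` contract (the parameter list of `create_optimizer` is truncated in the hunk above, so the signature below is an assumption):

```python
from collections.abc import Iterable

import torch.nn as nn
from torch.optim import Adam, Optimizer
from typing_extensions import override

from nshtrainer.optimizer import OptimizerConfigBase

# Hypothetical config; nshtrainer's own AdamWConfig may have a different schema.
class MyAdamConfig(OptimizerConfigBase):
    lr: float = 1e-3

    @override
    def create_optimizer(self, parameters: Iterable[nn.Parameter]) -> Optimizer:
        return Adam(parameters, lr=self.lr)

optimizer = MyAdamConfig().create_optimizer(nn.Linear(8, 1).parameters())
```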
nshtrainer/runner.py CHANGED
@@ -1,21 +1,31 @@
-from dataclasses import dataclass
 from typing import Generic
 
+from nshrunner import RunInfo
 from nshrunner import Runner as _Runner
-from typing_extensions import Concatenate, TypeVar, TypeVarTuple, Unpack, override
+from typing_extensions import TypeVar, TypeVarTuple, Unpack, override
 
 from .model.config import BaseConfig
 
 TConfig = TypeVar("TConfig", bound=BaseConfig, infer_variance=True)
-TArguments = TypeVarTuple("TArguments")
+TArguments = TypeVarTuple("TArguments", default=Unpack[tuple[()]])
 TReturn = TypeVar("TReturn", infer_variance=True)
 
 
-@dataclass(frozen=True)
 class Runner(
-    _Runner[Unpack[tuple[TConfig, Unpack[TArguments]]], TReturn],
-    Generic[TConfig, Unpack[TArguments], TReturn],
+    _Runner[TReturn, TConfig, Unpack[TArguments]],
+    Generic[TReturn, TConfig, Unpack[TArguments]],
 ):
     @override
-    def default_validate_fn():
-        pass
+    @classmethod
+    def default_validate_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> None:
+        super().default_validate_fn(config, *args)
+
+    @override
+    @classmethod
+    def default_info_fn(cls, config: TConfig, *args: Unpack[TArguments]) -> RunInfo:
+        run_info = super().default_info_fn(config, *args)
+        return {
+            **run_info,
+            "id": config.id,
+            "base_dir": config.directory.project_root,
+        }
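
Note two typing changes here: the type-parameter order becomes `(TReturn, TConfig, *TArguments)` to match the new nshrunner `Runner` base class, and `TArguments` now defaults to the empty tuple (PEP 696 defaults via `typing_extensions`), so `Runner` can be parameterized without naming any extra argument types. A self-contained sketch of the default mechanics:

```python
from typing import Generic

from typing_extensions import TypeVar, TypeVarTuple, Unpack

T = TypeVar("T")
Ts = TypeVarTuple("Ts", default=Unpack[tuple[()]])

class Box(Generic[T, Unpack[Ts]]):
    """Toy generic with the same (fixed, variadic) shape as Runner."""

# Ts defaults to the empty tuple, so a bare Box[int] is valid...
b: Box[int] = Box()
# ...while extra argument types can still be supplied explicitly.
b2: Box[int, str, float] = Box()
```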
nshtrainer/trainer/signal_connector.py CHANGED
@@ -25,14 +25,21 @@ _SIGNUM = int | signal.Signals
 _HANDLER: TypeAlias = Callable[[_SIGNUM, FrameType], Any] | int | signal.Handlers | None
 
 
-class _SignalConnector(_LightningSignalConnector):
-    def _auto_requeue_signals(self) -> list[signal.Signals]:
-        from ..model.base import BaseConfig
+def _resolve_requeue_signals():
+    signals: list[signal.Signals] = []
+
+    if timeout_signal_name := os.environ.get("NSHRUNNER_TIMEOUT_SIGNAL"):
+        signals.append(signal.Signals[timeout_signal_name])
 
-        if not isinstance(config := self.trainer.lightning_module.hparams, BaseConfig):
-            return []
+    if preempt_signal_name := os.environ.get("NSHRUNNER_PREEMPT_SIGNAL"):
+        signals.append(signal.Signals[preempt_signal_name])
 
-        signals = config.runner.submit._resolved_auto_requeue_signals()
+    return signals
+
+
+class _SignalConnector(_LightningSignalConnector):
+    def _auto_requeue_signals(self) -> list[signal.Signals]:
+        signals = _resolve_requeue_signals()
         signals_set = set(signals)
         valid_signals: set[signal.Signals] = signal.valid_signals()
         assert signals_set.issubset(
@@ -42,25 +49,29 @@ class _SignalConnector(_LightningSignalConnector):
 
     def _compose_and_register(
         self,
-        signum: _SIGNUM,
+        signum: signal.Signals,
         handlers: list[_HANDLER],
         replace_existing: bool = False,
    ):
         if self._is_on_windows():
-            log.info(f"Signal {signum} has no handlers or is not supported on Windows.")
+            log.info(
+                f"Signal {signum.name} has no handlers or is not supported on Windows."
+            )
             return
 
         if self._has_already_handler(signum):
             if not replace_existing:
                 log.info(
-                    f"Signal {signum} already has a handler. Adding ours to the existing one."
+                    f"Signal {signum.name} already has a handler. Adding ours to the existing one."
                 )
                 handlers.append(signal.getsignal(signum))
             else:
-                log.info(f"Replacing existing handler for signal {signum} with ours.")
+                log.info(
+                    f"Replacing existing handler for signal {signum.name} with ours."
+                )
 
         self._register_signal(signum, _HandlersCompose(handlers))
-        log.info(f"Registered {len(handlers)} handlers for signal {signum}.")
+        log.info(f"Registered {len(handlers)} handlers for signal {signum.name}.")
 
     @override
     def register_signal_handlers(self) -> None:
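
Note: requeue signals are no longer read from `config.runner.submit` (removed in this release) but from environment variables exported by nshrunner. A usage sketch (POSIX only, since `SIGURG` does not exist on Windows):

```python
import os
import signal

from nshtrainer.trainer.signal_connector import _resolve_requeue_signals

# nshrunner is assumed to export these when a timeout/preemption signal is
# configured; the values are signal names, resolved via signal.Signals[name].
os.environ["NSHRUNNER_TIMEOUT_SIGNAL"] = "SIGURG"
os.environ.pop("NSHRUNNER_PREEMPT_SIGNAL", None)

assert _resolve_requeue_signals() == [signal.Signals.SIGURG]
```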
nshtrainer/trainer/trainer.py CHANGED
@@ -31,7 +31,7 @@ log = logging.getLogger(__name__)
 
 def _is_bf16_supported_no_emulation():
     r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
-    version = cast(Any, torch.version)
+    version = getattr(torch, "version")
 
     # Check for ROCm, if true return true, no ROCM_VERSION check required,
     # since it is supported on AMD GPU archs.
nshtrainer/typecheck.py CHANGED
@@ -82,6 +82,7 @@ def typecheck_this_module(additional_modules: Sequence[str] = ()):
     frame = get_frame(1)
     assert frame is not None, "frame is None"
     calling_module_name = get_frame_package_name(frame)
+    assert calling_module_name is not None, "calling_module_name is None"
 
     # Typecheck the calling module + any additional modules.
     typecheck_modules((calling_module_name, *additional_modules))
{nshtrainer-0.1.0.dist-info → nshtrainer-0.2.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.1.0
+Version: 0.2.0
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -9,10 +9,21 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Requires-Dist: beartype (>=0.18.5,<0.19.0)
+Requires-Dist: jaxtyping (>=0.2.33,<0.3.0)
+Requires-Dist: lightning
+Requires-Dist: lovely-numpy (>=0.2.13,<0.3.0)
+Requires-Dist: lovely-tensors (>=0.1.16,<0.2.0)
 Requires-Dist: nshconfig (>=0.2.0,<0.3.0)
-Requires-Dist: nshrunner (>=0.1.0,<0.2.0)
+Requires-Dist: nshrunner (>=0.5.4,<0.6.0)
+Requires-Dist: numpy
+Requires-Dist: pysnooper
+Requires-Dist: pytorch-lightning
+Requires-Dist: rich
 Requires-Dist: torch
+Requires-Dist: torchmetrics
 Requires-Dist: typing-extensions
+Requires-Dist: wrapt
 Description-Content-Type: text/markdown
 
 
{nshtrainer-0.1.0.dist-info → nshtrainer-0.2.0.dist-info}/RECORD RENAMED
@@ -1,22 +1,16 @@
-nshtrainer/__init__.py,sha256=o39TbnjwUYzE4POcncUiDx02Ey-Hzx8UGuwJDjMcKZU,2971
+nshtrainer/__init__.py,sha256=_r7kBmgGSLVfActlqQeupNolrmBu45xUuSS8odt3HL8,2208
 nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
 nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
 nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
 nshtrainer/_experimental/flops/module_tracker.py,sha256=bUL-IRTd0aF_DwmXkZjHZAA31p4ZEhyqhc26XWKQUUY,4922
 nshtrainer/_snoop.py,sha256=Rofv1Rd92E0LY40G3A-o9Hu0ZI73RR59wJD5l4Q3PDM,7022
-nshtrainer/_submit/print_environment_info.py,sha256=enbJGl_iHIlhKN8avzKnoZSb0zUQ_fUdnsQ8a_9tbYk,963
-nshtrainer/_submit/session/_output.py,sha256=CNGH5W6_XxAC5-TRvMAMxOHd3fjGpJhK-7RGTDyvMu4,245
-nshtrainer/_submit/session/_script.py,sha256=0AeBgBduDsoIEBrY9kebARiBUEGc50JAD9oE_IDiLnA,3775
-nshtrainer/_submit/session/lsf.py,sha256=p19EP6OhROZxqfRhzeTD7GDmfYaREIKMXMOI8G933FE,14307
-nshtrainer/_submit/session/slurm.py,sha256=JpAjQvck4LjGN8o8fOvIeMuFqrg1cioANoVsX5hU-3g,17594
-nshtrainer/_submit/session/unified.py,sha256=gfh-AtnMyFHzcQOUlhlAR__vaWDk1r9XCivz_t_lHKk,11695
 nshtrainer/actsave/__init__.py,sha256=G1T-fELuGWkVqdhdyoePtj2dTOUtcIOW4VgsXv9JNTA,338
 nshtrainer/actsave/_callback.py,sha256=QoTa60F70f1RxB41VKixN9l5_htfFQxXDPHHSNFreuk,2770
 nshtrainer/actsave/_loader.py,sha256=fAhD32DrJa4onkYfcwc21YIeGEYzOSXCK_HVo9SZLgQ,4604
 nshtrainer/actsave/_saver.py,sha256=0EHmQDhqVxQWRWWSyt03eP1K9ETiACMQYmsZkDMt6HY,9451
-nshtrainer/callbacks/__init__.py,sha256=ohE_MO_kX1o4SZwcipIXUA9m7XYcijEKJtGcoU8dTkY,1667
+nshtrainer/callbacks/__init__.py,sha256=I6W33ityL9Ko8jjqHh3WH_8miV59SAe9LxInhoqX5XE,1665
 nshtrainer/callbacks/_throughput_monitor_callback.py,sha256=aJo_11rc4lo0IYOd-kHmPDtzdC4ctgXyRudkRJqH4m4,23184
-nshtrainer/callbacks/base.py,sha256=WESZz1VSTl1xSGVXBmxFqWwbLxXcJp97jpg9zrE0EsY,3560
+nshtrainer/callbacks/base.py,sha256=LrcRUV02bZEKXRIRvhHT9qsvw_kwoWiAdQkVMyKc5NU,3542
 nshtrainer/callbacks/early_stopping.py,sha256=jriSU761wf_qTJ9Bos0D3h5aDvZHYpRqK62Ne8aWp5I,3768
 nshtrainer/callbacks/ema.py,sha256=zKCtvzZFo0ORlwNZHjaMk-sJoxrlTtFWOzR-yGy95W0,12134
 nshtrainer/callbacks/finite_checks.py,sha256=kX3TIJsxyqx0GuLJfYsqVgKU27zwjG9Z8324lyCFtwM,2087
@@ -30,17 +24,17 @@ nshtrainer/callbacks/print_table.py,sha256=FcA-CBWwMf9c1NNRinvYpZC400RNQxuP28bJf
 nshtrainer/callbacks/throughput_monitor.py,sha256=YQLdpX3LGybIiD814yT9yCCVSEXRWf8WwsvVaN5aDBE,1848
 nshtrainer/callbacks/timer.py,sha256=sDXPPcdDKu5xnuK_bjr8plIq9MBuluNJ42Mt9LvPZzc,4610
 nshtrainer/callbacks/wandb_watch.py,sha256=pUpMsNxd03ex1rzOmFw2HzGOXjnQGaH84m8cc2dXo4g,2937
-nshtrainer/config.py,sha256=0Fj5w-ry0BRl2_zJI6jwCnmMWE3p_eD8_Wn-NyFkTqU,10442
+nshtrainer/config.py,sha256=IXOAl_JWFNX9kPTo_iw4Nc3qXqkKrbA6-ZrvTAjqu6A,104
 nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
 nshtrainer/data/balanced_batch_sampler.py,sha256=bcJBcQjh1hB1yKF_xSlT9AtEWv0BJjYc1CuH2BF-ea8,4392
 nshtrainer/data/transform.py,sha256=JeGxvytQly8hougrsdMmKG8gJ6qvFPDglJCO4Tp6STk,1795
-nshtrainer/lr_scheduler/__init__.py,sha256=GNGmkcJD3jgCMk7pfaanAYrKz9957qkx6_Q0rssiHK0,738
-nshtrainer/lr_scheduler/_base.py,sha256=1tWMABevKZAuGhJN8Me2E9eqEyqoLtsG0bADPjED7a4,3752
-nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=VhsxZJ_Mw9zjkAGunFQ1KRub5_QM5NRqaEFWtmedFp8,5212
-nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=Ct-uLo8Q4t7lJ_HwoLRhNmudnCw4cSnblpBEg22aVTI,2691
-nshtrainer/model/__init__.py,sha256=PdvZkpAVkqvCLipGJvEHFU3WxnSMxYpvtuOkvLIenxg,2078
-nshtrainer/model/base.py,sha256=bhngGHxr0suQB9Ezi_3d5JgDWYqS_yPgGJZrGmc1TnI,23571
-nshtrainer/model/config.py,sha256=RMDdrbtvwm5vTFPxQ2x1hqiBIEEE-OAknhF6KTWfkkk,70293
+nshtrainer/lr_scheduler/__init__.py,sha256=uEvgaFAs-4s_bAEMaildy0GT6OvgpgOEKTuzqutESHE,736
+nshtrainer/lr_scheduler/_base.py,sha256=7xOIuxQ86YHbFWG5a3gX46emQj1WN_LaY4-i0Q1TDBg,3659
+nshtrainer/lr_scheduler/linear_warmup_cosine.py,sha256=mn6cyizyI_stkXtg6zxIEGF9btIxMRWigUHUTlUYCSw,5221
+nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=2ZdlV0RUMwg2DClzqYHr8_EKT1jZBUlSD39e-XlCsC4,2764
+nshtrainer/model/__init__.py,sha256=y32Hla-5whpzLL2BtCJpBakSp8o-1nQbpO0j_-xq_Po,1864
+nshtrainer/model/base.py,sha256=EMkOtp4YWGPHM0HPSTLbx75T9vlYmXO4XyD725xU70w,21453
+nshtrainer/model/config.py,sha256=6lATW6-Z1SIDgQ1IWrGBVQKTr8DhL5b_rFbJHQz0d5o,66796
 nshtrainer/model/modules/callback.py,sha256=JF59U9-CjJsAIspEhTJbVaGN0wGctZG7UquE3IS7R8A,6408
 nshtrainer/model/modules/debug.py,sha256=DTVty8cKnzj1GCULRyGx_sWTTsq9NLi30dzqjRTnuCU,1127
 nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -52,21 +46,20 @@ nshtrainer/nn/__init__.py,sha256=57LPaP3G-BBGD2eGxbBUABNgYl3s_oASwrtOSS4bzTs,133
 nshtrainer/nn/mlp.py,sha256=i-dHk0tomO_XlU6cKN4CC4HxTaYb-ukBCAgY1ySXl4I,3963
 nshtrainer/nn/module_dict.py,sha256=NOY0B6WDTnktyWH4GthsprMQo0bpehC-hCq9SfD8paE,2329
 nshtrainer/nn/module_list.py,sha256=fb2u5Rqdjff8Pekyr9hkCPkBorQ-fldzzFAjsgWAm30,1719
-nshtrainer/nn/nonlinearity.py,sha256=IhIR8NCTY3Np9dMDnUouERR8ZhWpK3S0hTbT0i8HezU,3645
-nshtrainer/optimizer.py,sha256=JiLNRtcfYxyhAab1Z1QcEzmrX9S_JyrBS67TXy12kXI,1557
-nshtrainer/runner.py,sha256=9HsYB58aasY9RVvya_gPECDs_MBhM1fl4cbM3iJYTDc,600
+nshtrainer/nn/nonlinearity.py,sha256=owtU4kh4G98psD0axOJWVfBhm-OtJVgFM-TXSHmbNPU,3625
+nshtrainer/optimizer.py,sha256=kuJEA1pvB3y1FcsfhAoOJujVqEZqFHlmYO8GW6JeA1g,1527
+nshtrainer/runner.py,sha256=vyHr0EZ0PBOWZh09BtOOxio-FRQZFbVoL4cdBlI97vY,991
 nshtrainer/scripts/check_env.py,sha256=IMl6dSqsLYppI0XuCsVq8lK4bYqXwY9KHJkzsShz4Kg,806
 nshtrainer/scripts/find_packages.py,sha256=FbdlfmAefttFSMfaT0A46a-oHLP_ioaQKihwBfBeWeA,1467
 nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
-nshtrainer/trainer/signal_connector.py,sha256=aGg6kRiHiqtAdGlEvEvGLmOy7AvRHTSkXdTmZpRXbjU,8435
-nshtrainer/trainer/trainer.py,sha256=oi8KdHF1AdZ54KFbCFAEI7W-C7qRtRe-KtOjNwBuS3M,14033
-nshtrainer/typecheck.py,sha256=CFkmPIxCU24CHk_7_pykb-Y1PRNhpLgsVZw1zuuOS_U,4614
+nshtrainer/trainer/signal_connector.py,sha256=QAoPM_C5JJOVQebcrJOimUUD3GHyoeZUqCEAvzZlT4U,8710
+nshtrainer/trainer/trainer.py,sha256=eYEYfY9v70MuorHcSf8nqM7f2CkmUHhpPcjCk4FJD7k,14034
+nshtrainer/typecheck.py,sha256=RGYHxDBcs97E6ayl6Olc43JBZXQolCtMxcLBniVCVBg,4688
 nshtrainer/util/environment.py,sha256=_SEtiQ_s5bL5pllUlf96AOUv15kNvCPvocVC13S7mIk,4166
 nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
-nshtrainer/util/singleton.py,sha256=nLhpuMZxl0zdNsnvS97o4ASUnKzCWYEKLzR_j9oP_xs,2208
 nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
 nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
 nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
-nshtrainer-0.1.0.dist-info/METADATA,sha256=3zdNPxyB-I6Gudq2gTaU0crdgmDCcGCp6Zudef0DtuM,529
-nshtrainer-0.1.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-nshtrainer-0.1.0.dist-info/RECORD,,
+nshtrainer-0.2.0.dist-info/METADATA,sha256=cwb3IbKGyJ9HbNSvsORYhCiI61nrDMb1dVm5nE1q_XA,882
+nshtrainer-0.2.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+nshtrainer-0.2.0.dist-info/RECORD,,
nshtrainer/_submit/print_environment_info.py DELETED
@@ -1,31 +0,0 @@
-import logging
-import os
-import sys
-
-
-def print_environment_info(log: logging.Logger | None = None):
-    if log is None:
-        logging.basicConfig(level=logging.INFO)
-        log = logging.getLogger(__name__)
-
-    log_message_lines: list[str] = []
-    log_message_lines.append("Python executable: " + sys.executable)
-    log_message_lines.append("Python version: " + sys.version)
-    log_message_lines.append("Python prefix: " + sys.prefix)
-    log_message_lines.append("Python path:")
-    for path in sys.path:
-        log_message_lines.append(f"  {path}")
-
-    log_message_lines.append("Environment variables:")
-    for key, value in os.environ.items():
-        log_message_lines.append(f"  {key}={value}")
-
-    log_message_lines.append("Command line arguments:")
-    for i, arg in enumerate(sys.argv):
-        log_message_lines.append(f"  {i}: {arg}")
-
-    log.critical("\n".join(log_message_lines))
-
-
-if __name__ == "__main__":
-    print_environment_info()