nshtrainer 1.1.0__py3-none-any.whl → 1.1.1b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +0 -3
- nshtrainer/configs/__init__.py +3 -3
- nshtrainer/data/datamodule.py +2 -2
- nshtrainer/model/base.py +2 -9
- nshtrainer/trainer/_log_hparams.py +85 -0
- nshtrainer/trainer/trainer.py +4 -0
- {nshtrainer-1.1.0.dist-info → nshtrainer-1.1.1b1.dist-info}/METADATA +1 -1
- {nshtrainer-1.1.0.dist-info → nshtrainer-1.1.1b1.dist-info}/RECORD +9 -8
- {nshtrainer-1.1.0.dist-info → nshtrainer-1.1.1b1.dist-info}/WHEEL +0 -0
nshtrainer/__init__.py
CHANGED
@@ -14,9 +14,6 @@ from .metrics import MetricConfig as MetricConfig
|
|
14
14
|
from .model import LightningModuleBase as LightningModuleBase
|
15
15
|
from .trainer import Trainer as Trainer
|
16
16
|
from .trainer import TrainerConfig as TrainerConfig
|
17
|
-
from .trainer import accelerator_registry as accelerator_registry
|
18
|
-
from .trainer import callback_registry as callback_registry
|
19
|
-
from .trainer import plugin_registry as plugin_registry
|
20
17
|
|
21
18
|
try:
|
22
19
|
from . import configs as configs
|
nshtrainer/configs/__init__.py
CHANGED
@@ -4,9 +4,6 @@ __codegen__ = True
|
|
4
4
|
|
5
5
|
from nshtrainer import MetricConfig as MetricConfig
|
6
6
|
from nshtrainer import TrainerConfig as TrainerConfig
|
7
|
-
from nshtrainer import accelerator_registry as accelerator_registry
|
8
|
-
from nshtrainer import callback_registry as callback_registry
|
9
|
-
from nshtrainer import plugin_registry as plugin_registry
|
10
7
|
from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
|
11
8
|
from nshtrainer._directory import DirectoryConfig as DirectoryConfig
|
12
9
|
from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
|
@@ -14,6 +11,7 @@ from nshtrainer._hf_hub import (
|
|
14
11
|
HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
|
15
12
|
)
|
16
13
|
from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
|
14
|
+
from nshtrainer._hf_hub import callback_registry as callback_registry
|
17
15
|
from nshtrainer.callbacks import ActSaveConfig as ActSaveConfig
|
18
16
|
from nshtrainer.callbacks import (
|
19
17
|
BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
|
@@ -106,6 +104,8 @@ from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
|
|
106
104
|
from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
|
107
105
|
from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
|
108
106
|
from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
|
107
|
+
from nshtrainer.trainer import accelerator_registry as accelerator_registry
|
108
|
+
from nshtrainer.trainer import plugin_registry as plugin_registry
|
109
109
|
from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
|
110
110
|
from nshtrainer.trainer._config import (
|
111
111
|
CheckpointCallbackConfig as CheckpointCallbackConfig,
|
nshtrainer/data/datamodule.py
CHANGED
@@ -30,9 +30,9 @@ class LightningDataModuleBase(
|
|
30
30
|
|
31
31
|
@property
|
32
32
|
@override
|
33
|
-
def hparams_initial(self): # pyright: ignore[reportIncompatibleMethodOverride]
|
33
|
+
def hparams_initial(self) -> THparams: # pyright: ignore[reportIncompatibleMethodOverride]
|
34
34
|
hparams = cast(THparams, super().hparams_initial)
|
35
|
-
return
|
35
|
+
return hparams
|
36
36
|
|
37
37
|
@property
|
38
38
|
@deprecated("Use `hparams` instead")
|
nshtrainer/model/base.py
CHANGED
@@ -134,16 +134,9 @@ class LightningModuleBase(
|
|
134
134
|
|
135
135
|
@property
|
136
136
|
@override
|
137
|
-
def hparams_initial(self): # pyright: ignore[reportIncompatibleMethodOverride]
|
137
|
+
def hparams_initial(self) -> THparams: # pyright: ignore[reportIncompatibleMethodOverride]
|
138
138
|
hparams = cast(THparams, super().hparams_initial)
|
139
|
-
|
140
|
-
if (trainer := self._trainer) is not None:
|
141
|
-
from ..trainer import Trainer
|
142
|
-
|
143
|
-
if isinstance(trainer, Trainer):
|
144
|
-
hparams_dict["trainer"] = trainer.hparams.model_dump(mode="json")
|
145
|
-
|
146
|
-
return cast(Never, hparams_dict)
|
139
|
+
return hparams
|
147
140
|
|
148
141
|
@property
|
149
142
|
@deprecated("Use `hparams` instead")
|
@@ -0,0 +1,85 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import json
|
4
|
+
import logging
|
5
|
+
from typing import Any, cast
|
6
|
+
|
7
|
+
import nshconfig as C
|
8
|
+
from lightning.pytorch import LightningDataModule, Trainer
|
9
|
+
|
10
|
+
log = logging.getLogger(__name__)
|
11
|
+
|
12
|
+
|
13
|
+
def _dict(obj: Any):
|
14
|
+
if isinstance(obj, C.Config):
|
15
|
+
return obj.model_dump(mode="json")
|
16
|
+
|
17
|
+
try:
|
18
|
+
return dict(obj)
|
19
|
+
except Exception:
|
20
|
+
return json.loads(
|
21
|
+
json.dumps(obj, default=lambda o: str(o), indent=4, sort_keys=True)
|
22
|
+
)
|
23
|
+
|
24
|
+
|
25
|
+
def _dict_and_clean(obj: Any):
    """Convert ``obj`` with ``_dict`` and drop entries that should not be logged.

    The only entry dropped is ``"_class_path"``, an internal hyperparameter
    that LightningCLI stores next to the user's own values.

    Returns a newly built ``dict``; the mapping produced by ``_dict`` is only
    read, never modified.
    """
    source = _dict(obj)
    cleaned = {}
    for key, value in source.items():
        if key == "_class_path":
            continue  # LightningCLI-internal entry, not a user hyperparameter
        cleaned[key] = value
    return cleaned
|
31
|
+
|
32
|
+
|
33
|
+
def _log_hyperparams(trainer: Trainer) -> None:
|
34
|
+
if not trainer.loggers:
|
35
|
+
return
|
36
|
+
|
37
|
+
hparams_to_log: dict[str, Any] = {}
|
38
|
+
|
39
|
+
from .trainer import Trainer
|
40
|
+
|
41
|
+
if isinstance(trainer, Trainer):
|
42
|
+
hparams_to_log["trainer"] = _dict_and_clean(trainer.hparams)
|
43
|
+
|
44
|
+
if (
|
45
|
+
pl_module := trainer.lightning_module
|
46
|
+
) is not None and pl_module._log_hyperparams:
|
47
|
+
hparams_to_log["model"] = _dict_and_clean(pl_module.hparams_initial)
|
48
|
+
|
49
|
+
if (
|
50
|
+
datamodule := cast(
|
51
|
+
LightningDataModule | None, getattr(trainer, "datamodule", None)
|
52
|
+
)
|
53
|
+
) is not None and (datamodule._log_hyperparams):
|
54
|
+
hparams_to_log["datamodule"] = _dict_and_clean(datamodule.hparams_initial)
|
55
|
+
|
56
|
+
for logger in trainer.loggers:
|
57
|
+
logger.log_hyperparams(hparams_to_log)
|
58
|
+
logger.log_graph(pl_module)
|
59
|
+
logger.save()
|
60
|
+
|
61
|
+
|
62
|
+
def patch_log_hparams_function():
    """Make Lightning use nshtrainer's ``_log_hyperparams``.

    Targets two distributions: ``lightning.pytorch`` and the standalone
    ``pytorch_lightning``. For each one, the module-level ``_log_hyperparams``
    name is rebound to this module's ``_log_hyperparams`` in both of:

    - ``<pkg>.loggers.utilities``;
    - ``<pkg>.trainer.trainer``.

    The trainer module is patched as well. Presumably it holds its own
    reference to the function (a ``from ... import``), which rebinding only
    the utilities module would not affect.

    A distribution whose modules cannot be imported is skipped silently, so
    having only one of them (or neither) installed is fine. A distribution is
    patched completely or not at all. If a target module does not define
    ``_log_hyperparams``, a warning is logged, since the patch is then
    unlikely to take effect for that Lightning version.
    """
    import importlib

    for package in ("lightning.pytorch", "pytorch_lightning"):
        try:
            # Import both targets before patching either one.
            targets = [
                importlib.import_module(f"{package}.loggers.utilities"),
                importlib.import_module(f"{package}.trainer.trainer"),
            ]
        except ImportError:
            # This distribution is not installed; nothing to patch.
            continue

        for module in targets:
            if not hasattr(module, "_log_hyperparams"):
                log.warning(
                    "%s does not define `_log_hyperparams`; nshtrainer's "
                    "hyperparameter logging patch may have no effect.",
                    module.__name__,
                )
            setattr(module, "_log_hyperparams", _log_hyperparams)

        log.info(
            "Patched %s's _log_hyperparams to use nshtrainer's version", package
        )
|
nshtrainer/trainer/trainer.py
CHANGED
@@ -23,6 +23,7 @@ from ..callbacks.base import resolve_all_callbacks
|
|
23
23
|
from ..util._environment_info import EnvironmentConfig
|
24
24
|
from ..util.bf16 import is_bf16_supported_no_emulation
|
25
25
|
from ._config import LightningTrainerKwargs, TrainerConfig
|
26
|
+
from ._log_hparams import patch_log_hparams_function
|
26
27
|
from ._runtime_callback import RuntimeTrackerCallback, Stage
|
27
28
|
from .accelerator import AcceleratorConfigBase
|
28
29
|
from .signal_connector import _SignalConnector
|
@@ -31,6 +32,9 @@ from .strategy import StrategyConfigBase
|
|
31
32
|
log = logging.getLogger(__name__)
|
32
33
|
|
33
34
|
|
35
|
+
patch_log_hparams_function()
|
36
|
+
|
37
|
+
|
34
38
|
class Trainer(LightningTrainer):
|
35
39
|
CHECKPOINT_HYPER_PARAMS_KEY = "trainer_hyper_parameters"
|
36
40
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
|
2
|
-
nshtrainer/__init__.py,sha256=
|
2
|
+
nshtrainer/__init__.py,sha256=VcqBfL8RgCcZDaY645nxeDmOspqerx4x46wggCMnS0E,692
|
3
3
|
nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
|
4
4
|
nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2uLzQE,5597
|
5
5
|
nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
|
@@ -32,7 +32,7 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
|
|
32
32
|
nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
|
33
33
|
nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
|
34
34
|
nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
|
35
|
-
nshtrainer/configs/__init__.py,sha256
|
35
|
+
nshtrainer/configs/__init__.py,sha256=4WNs4Zv4PtHWD0KKH4X7j_zFt-COrEB0KhNIljsA6Rc,14740
|
36
36
|
nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
|
37
37
|
nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
|
38
38
|
nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
|
@@ -101,7 +101,7 @@ nshtrainer/configs/util/config/dtype/__init__.py,sha256=PmGF-O4r6SXqEaagVsQ5YxEq
|
|
101
101
|
nshtrainer/configs/util/config/duration/__init__.py,sha256=44lS2irOIPVfgshMTfnZM2jC6l0Pjst9w2M_lJoS_MU,353
|
102
102
|
nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,260
|
103
103
|
nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
|
104
|
-
nshtrainer/data/datamodule.py,sha256=
|
104
|
+
nshtrainer/data/datamodule.py,sha256=Rb4-mA8iXtjRlNUHcIqVPEvxA_VkiJXwN1EvHIsydJ0,4095
|
105
105
|
nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
|
106
106
|
nshtrainer/loggers/__init__.py,sha256=fI0OHEltHP4tZI-KFB3npdzoxm_M2QsEYKxY3um05_s,592
|
107
107
|
nshtrainer/loggers/actsave.py,sha256=wgNrpBB6wQM7qff8iLDb_sQnbiAcYHRmH56pcEJPB3o,1409
|
@@ -116,7 +116,7 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=irPyDjfUX843ze4bJM9sW8WSe
|
|
116
116
|
nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
|
117
117
|
nshtrainer/metrics/_config.py,sha256=ox_ScK6V0J9nzIMhEB0qpToNKpt83VVgOVSRFCV-wBc,595
|
118
118
|
nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
|
119
|
-
nshtrainer/model/base.py,sha256=
|
119
|
+
nshtrainer/model/base.py,sha256=LsOK5mMhYG5J0eSFKZKdd1fTvr38sgi8LLVSqoW6OCU,8386
|
120
120
|
nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
|
121
121
|
nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
|
122
122
|
nshtrainer/model/mixins/logger.py,sha256=7u9fQig-SVFA9RFIB4U0gqJAzruh49mgmXXvZ6VkDUk,11694
|
@@ -134,6 +134,7 @@ nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1
|
|
134
134
|
nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
|
135
135
|
nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
|
136
136
|
nshtrainer/trainer/_config.py,sha256=s-_XoLc9mbNAdroRJyOKd3dLTyrFLQkPyGJkKDmBYf8,33267
|
137
|
+
nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
|
137
138
|
nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
|
138
139
|
nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
|
139
140
|
nshtrainer/trainer/plugin/__init__.py,sha256=q_q98MYNaZ2VE_tqGqYlQjQnlaF4NE1FUqVVbj0EK7k,517
|
@@ -144,7 +145,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLc
|
|
144
145
|
nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
|
145
146
|
nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
|
146
147
|
nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
|
147
|
-
nshtrainer/trainer/trainer.py,sha256=
|
148
|
+
nshtrainer/trainer/trainer.py,sha256=BKRicDlLI7KstzuP0SmzJzp0U4GK5lhZcKHS1IuL5sA,21197
|
148
149
|
nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
|
149
150
|
nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
|
150
151
|
nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
|
@@ -156,6 +157,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
|
|
156
157
|
nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
|
157
158
|
nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
|
158
159
|
nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
|
159
|
-
nshtrainer-1.1.
|
160
|
-
nshtrainer-1.1.
|
161
|
-
nshtrainer-1.1.
|
160
|
+
nshtrainer-1.1.1b1.dist-info/METADATA,sha256=wdOIQ91eUgWrIHfPLP06FD4uMkyyIfToR3VhBY-BXsE,962
|
161
|
+
nshtrainer-1.1.1b1.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
|
162
|
+
nshtrainer-1.1.1b1.dist-info/RECORD,,
|
File without changes
|