nshtrainer 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/__init__.py CHANGED
@@ -14,9 +14,6 @@ from .metrics import MetricConfig as MetricConfig
14
14
  from .model import LightningModuleBase as LightningModuleBase
15
15
  from .trainer import Trainer as Trainer
16
16
  from .trainer import TrainerConfig as TrainerConfig
17
- from .trainer import accelerator_registry as accelerator_registry
18
- from .trainer import callback_registry as callback_registry
19
- from .trainer import plugin_registry as plugin_registry
20
17
 
21
18
  try:
22
19
  from . import configs as configs
@@ -4,9 +4,6 @@ __codegen__ = True
4
4
 
5
5
  from nshtrainer import MetricConfig as MetricConfig
6
6
  from nshtrainer import TrainerConfig as TrainerConfig
7
- from nshtrainer import accelerator_registry as accelerator_registry
8
- from nshtrainer import callback_registry as callback_registry
9
- from nshtrainer import plugin_registry as plugin_registry
10
7
  from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
11
8
  from nshtrainer._directory import DirectoryConfig as DirectoryConfig
12
9
  from nshtrainer._hf_hub import CallbackConfigBase as CallbackConfigBase
@@ -14,6 +11,7 @@ from nshtrainer._hf_hub import (
14
11
  HuggingFaceHubAutoCreateConfig as HuggingFaceHubAutoCreateConfig,
15
12
  )
16
13
  from nshtrainer._hf_hub import HuggingFaceHubConfig as HuggingFaceHubConfig
14
+ from nshtrainer._hf_hub import callback_registry as callback_registry
17
15
  from nshtrainer.callbacks import ActSaveConfig as ActSaveConfig
18
16
  from nshtrainer.callbacks import (
19
17
  BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
@@ -106,6 +104,8 @@ from nshtrainer.profiler import BaseProfilerConfig as BaseProfilerConfig
106
104
  from nshtrainer.profiler import ProfilerConfig as ProfilerConfig
107
105
  from nshtrainer.profiler import PyTorchProfilerConfig as PyTorchProfilerConfig
108
106
  from nshtrainer.profiler import SimpleProfilerConfig as SimpleProfilerConfig
107
+ from nshtrainer.trainer import accelerator_registry as accelerator_registry
108
+ from nshtrainer.trainer import plugin_registry as plugin_registry
109
109
  from nshtrainer.trainer._config import AcceleratorConfig as AcceleratorConfig
110
110
  from nshtrainer.trainer._config import (
111
111
  CheckpointCallbackConfig as CheckpointCallbackConfig,
@@ -30,9 +30,9 @@ class LightningDataModuleBase(
30
30
 
31
31
  @property
32
32
  @override
33
- def hparams_initial(self): # pyright: ignore[reportIncompatibleMethodOverride]
33
+ def hparams_initial(self) -> THparams: # pyright: ignore[reportIncompatibleMethodOverride]
34
34
  hparams = cast(THparams, super().hparams_initial)
35
- return cast(Never, {"datamodule": hparams.model_dump(mode="json")})
35
+ return hparams
36
36
 
37
37
  @property
38
38
  @deprecated("Use `hparams` instead")
nshtrainer/model/base.py CHANGED
@@ -134,16 +134,9 @@ class LightningModuleBase(
134
134
 
135
135
  @property
136
136
  @override
137
- def hparams_initial(self): # pyright: ignore[reportIncompatibleMethodOverride]
137
+ def hparams_initial(self) -> THparams: # pyright: ignore[reportIncompatibleMethodOverride]
138
138
  hparams = cast(THparams, super().hparams_initial)
139
- hparams_dict = {"model": hparams.model_dump(mode="json")}
140
- if (trainer := self._trainer) is not None:
141
- from ..trainer import Trainer
142
-
143
- if isinstance(trainer, Trainer):
144
- hparams_dict["trainer"] = trainer.hparams.model_dump(mode="json")
145
-
146
- return cast(Never, hparams_dict)
139
+ return hparams
147
140
 
148
141
  @property
149
142
  @deprecated("Use `hparams` instead")
@@ -0,0 +1,85 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ from typing import Any, cast
6
+
7
+ import nshconfig as C
8
+ from lightning.pytorch import LightningDataModule, Trainer
9
+
10
+ log = logging.getLogger(__name__)
11
+
12
+
13
+ def _dict(obj: Any):
14
+ if isinstance(obj, C.Config):
15
+ return obj.model_dump(mode="json")
16
+
17
+ try:
18
+ return dict(obj)
19
+ except Exception:
20
+ return json.loads(
21
+ json.dumps(obj, default=lambda o: str(o), indent=4, sort_keys=True)
22
+ )
23
+
24
+
25
+ def _dict_and_clean(obj: Any):
26
+ d = _dict(obj)
27
+
28
+ # Remove LightningCLI's internal hparam
29
+ d = {k: v for k, v in d.items() if k != "_class_path"}
30
+ return d
31
+
32
+
33
+ def _log_hyperparams(trainer: Trainer) -> None:
34
+ if not trainer.loggers:
35
+ return
36
+
37
+ hparams_to_log: dict[str, Any] = {}
38
+
39
+ from .trainer import Trainer
40
+
41
+ if isinstance(trainer, Trainer):
42
+ hparams_to_log["trainer"] = _dict_and_clean(trainer.hparams)
43
+
44
+ if (
45
+ pl_module := trainer.lightning_module
46
+ ) is not None and pl_module._log_hyperparams:
47
+ hparams_to_log["model"] = _dict_and_clean(pl_module.hparams_initial)
48
+
49
+ if (
50
+ datamodule := cast(
51
+ LightningDataModule | None, getattr(trainer, "datamodule", None)
52
+ )
53
+ ) is not None and (datamodule._log_hyperparams):
54
+ hparams_to_log["datamodule"] = _dict_and_clean(datamodule.hparams_initial)
55
+
56
+ for logger in trainer.loggers:
57
+ logger.log_hyperparams(hparams_to_log)
58
+ logger.log_graph(pl_module)
59
+ logger.save()
60
+
61
+
62
+ def patch_log_hparams_function():
63
+ try:
64
+ import lightning.pytorch.loggers.utilities
65
+ import lightning.pytorch.trainer.trainer
66
+
67
+ lightning.pytorch.loggers.utilities._log_hyperparams = _log_hyperparams
68
+ lightning.pytorch.trainer.trainer._log_hyperparams = _log_hyperparams
69
+ log.info(
70
+ "Patched lightning.pytorch's _log_hyperparams to use nshtrainer's version"
71
+ )
72
+ except ImportError:
73
+ pass
74
+
75
+ try:
76
+ import pytorch_lightning.loggers.utilities
77
+ import pytorch_lightning.trainer.trainer
78
+
79
+ pytorch_lightning.loggers.utilities._log_hyperparams = _log_hyperparams
80
+ pytorch_lightning.trainer.trainer._log_hyperparams = _log_hyperparams
81
+ log.info(
82
+ "Patched pytorch_lightning's _log_hyperparams to use nshtrainer's version"
83
+ )
84
+ except ImportError:
85
+ pass
@@ -23,6 +23,7 @@ from ..callbacks.base import resolve_all_callbacks
23
23
  from ..util._environment_info import EnvironmentConfig
24
24
  from ..util.bf16 import is_bf16_supported_no_emulation
25
25
  from ._config import LightningTrainerKwargs, TrainerConfig
26
+ from ._log_hparams import patch_log_hparams_function
26
27
  from ._runtime_callback import RuntimeTrackerCallback, Stage
27
28
  from .accelerator import AcceleratorConfigBase
28
29
  from .signal_connector import _SignalConnector
@@ -31,6 +32,9 @@ from .strategy import StrategyConfigBase
31
32
  log = logging.getLogger(__name__)
32
33
 
33
34
 
35
+ patch_log_hparams_function()
36
+
37
+
34
38
  class Trainer(LightningTrainer):
35
39
  CHECKPOINT_HYPER_PARAMS_KEY = "trainer_hyper_parameters"
36
40
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: nshtrainer
3
- Version: 1.1.0
3
+ Version: 1.1.2
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,5 +1,5 @@
1
1
  nshtrainer/.nshconfig.generated.json,sha256=yZd6cn1RhvNNJUgiUTRYut8ofZYvbulnpPG-rZIRhi4,106
2
- nshtrainer/__init__.py,sha256=g_moPnfQxSxFZX5NB9ILQQOJrt4RTRuiFt9N0STIpxM,874
2
+ nshtrainer/__init__.py,sha256=VcqBfL8RgCcZDaY645nxeDmOspqerx4x46wggCMnS0E,692
3
3
  nshtrainer/_callback.py,sha256=ZDppiJ4d65tRXTEWYPZLH_F1xFizdz1pkWJe_sQ5uII,12564
4
4
  nshtrainer/_checkpoint/metadata.py,sha256=Hh5a7OkdknUEbkEwX6vS88-XLEeuVDoR6a3en2uLzQE,5597
5
5
  nshtrainer/_checkpoint/saver.py,sha256=utcrYKSosd04N9m2GIylufO5DO05D90qVU3mvadfApU,1658
@@ -32,7 +32,7 @@ nshtrainer/callbacks/timer.py,sha256=gDcw_K_ikf0bkVgxQ0cDhvvNvz6GLZVLcatuKfh0ORU
32
32
  nshtrainer/callbacks/wandb_upload_code.py,sha256=shV7UtnXgY2bUlXdVrXiaDs0PNLlIt7TzNJkJPkzvzI,2414
33
33
  nshtrainer/callbacks/wandb_watch.py,sha256=VB14Dy5ZRXQ3di0fPv0K_DFJurLhroLPytnuwQBiJFg,3037
34
34
  nshtrainer/configs/.gitattributes,sha256=VeZmarvNEqiRBOHGcllpKm90nL6C8u4tBu7SEm7fj-E,26
35
- nshtrainer/configs/__init__.py,sha256=-rGk9pnRnuz4yKvACGOpY3nkrWnHholqZGk7UP2Vkrc,14716
35
+ nshtrainer/configs/__init__.py,sha256=4WNs4Zv4PtHWD0KKH4X7j_zFt-COrEB0KhNIljsA6Rc,14740
36
36
  nshtrainer/configs/_checkpoint/__init__.py,sha256=6s7Y68StboqscY2G4P_QG443jz5aiym5SjOogIljWLg,342
37
37
  nshtrainer/configs/_checkpoint/metadata/__init__.py,sha256=oOPfYkXTjKgm6pluGsG6V1TPyCEGjsQpHVL-LffSUFQ,290
38
38
  nshtrainer/configs/_directory/__init__.py,sha256=_oO7vM9DhzHSxtZcv86sTi7hZIptnK1gr-AP9mqQ370,386
@@ -101,7 +101,7 @@ nshtrainer/configs/util/config/dtype/__init__.py,sha256=PmGF-O4r6SXqEaagVsQ5YxEq
101
101
  nshtrainer/configs/util/config/duration/__init__.py,sha256=44lS2irOIPVfgshMTfnZM2jC6l0Pjst9w2M_lJoS_MU,353
102
102
  nshtrainer/data/__init__.py,sha256=K4i3Tw4g9EOK2zlMMbidi99y0SyI4w8P7_XUf1n42Ts,260
103
103
  nshtrainer/data/balanced_batch_sampler.py,sha256=r1cBKRXKHD8E1Ax6tj-FUbE-z1qpbO58mQ9VrK9uLnc,5481
104
- nshtrainer/data/datamodule.py,sha256=0M-HjGZQkLG77HXn4ZgLSypnbSjkjTq6GEJwGWe_gbM,4136
104
+ nshtrainer/data/datamodule.py,sha256=Rb4-mA8iXtjRlNUHcIqVPEvxA_VkiJXwN1EvHIsydJ0,4095
105
105
  nshtrainer/data/transform.py,sha256=qd0lIocO59Fk_m90xyOHgFezbymd1mRwly8nbYIfHGc,2263
106
106
  nshtrainer/loggers/__init__.py,sha256=fI0OHEltHP4tZI-KFB3npdzoxm_M2QsEYKxY3um05_s,592
107
107
  nshtrainer/loggers/actsave.py,sha256=wgNrpBB6wQM7qff8iLDb_sQnbiAcYHRmH56pcEJPB3o,1409
@@ -116,7 +116,7 @@ nshtrainer/lr_scheduler/reduce_lr_on_plateau.py,sha256=irPyDjfUX843ze4bJM9sW8WSe
116
116
  nshtrainer/metrics/__init__.py,sha256=Nqkn_jsDf3n5WtfMcnaaEftYjIIT2b-S7rmsB1MOMkU,86
117
117
  nshtrainer/metrics/_config.py,sha256=ox_ScK6V0J9nzIMhEB0qpToNKpt83VVgOVSRFCV-wBc,595
118
118
  nshtrainer/model/__init__.py,sha256=3G-bwPPSRStWdsdwG9-rn0bXcRpEiP1BiQpF_qavtls,97
119
- nshtrainer/model/base.py,sha256=bZMNap0rkxRbAbu2BOHV_6YS2iZZnvy6wVSMOXGa_ZM,8680
119
+ nshtrainer/model/base.py,sha256=LsOK5mMhYG5J0eSFKZKdd1fTvr38sgi8LLVSqoW6OCU,8386
120
120
  nshtrainer/model/mixins/callback.py,sha256=0LPgve4VszHbLipid4mpI1qnnmdGS2spivs0dXLvqHw,3154
121
121
  nshtrainer/model/mixins/debug.py,sha256=ydLuAAaa7M5bX0gougZ5gWuZnvn4Ra9assal3IZ9hq8,2086
122
122
  nshtrainer/model/mixins/logger.py,sha256=7u9fQig-SVFA9RFIB4U0gqJAzruh49mgmXXvZ6VkDUk,11694
@@ -134,6 +134,7 @@ nshtrainer/profiler/pytorch.py,sha256=8K37XvPnCApUpIK8tA2zNMFIaIiTLSoxKQoiyCPBm1
134
134
  nshtrainer/profiler/simple.py,sha256=PimjqcU-JuS-8C0ZGHAdwCxgNLij4x0FH6WXsjBQzZs,1005
135
135
  nshtrainer/trainer/__init__.py,sha256=fQ7gQRlGWX-90TYT0rttkQyvXDCzo7DAvJgr-jX1zsY,316
136
136
  nshtrainer/trainer/_config.py,sha256=s-_XoLc9mbNAdroRJyOKd3dLTyrFLQkPyGJkKDmBYf8,33267
137
+ nshtrainer/trainer/_log_hparams.py,sha256=XH2lZ4U_3AZBhOt91ocsEhdL_NRz35oWvqLCUFDohUs,2389
137
138
  nshtrainer/trainer/_runtime_callback.py,sha256=6F2Gq27Q8OFfN3RtdNC6QRA8ac0LC1hh4DUE3V5WgbI,4217
138
139
  nshtrainer/trainer/accelerator.py,sha256=Bqq-ry7DeCY4zw9_zBvTZiijpA-uUHrDjtbLV652m4M,2415
139
140
  nshtrainer/trainer/plugin/__init__.py,sha256=q_q98MYNaZ2VE_tqGqYlQjQnlaF4NE1FUqVVbj0EK7k,517
@@ -144,7 +145,7 @@ nshtrainer/trainer/plugin/layer_sync.py,sha256=-BbEyWZ063O7tZme7Gdu1lVxK6p1NeuLc
144
145
  nshtrainer/trainer/plugin/precision.py,sha256=7lf7KZd_yFyPmhLApjEIv0pkoDB5zdxi-7in0wRj3z8,5436
145
146
  nshtrainer/trainer/signal_connector.py,sha256=GhfGcSzfaTNhnj2QFkBDq5aT7FqbLMA7eC8SYQs8_8w,10828
146
147
  nshtrainer/trainer/strategy.py,sha256=VPTn5z3zvXTydY8IJchjhjcOfpvtoejnvUkq5E4WTus,1368
147
- nshtrainer/trainer/trainer.py,sha256=Lo3vUo3ooTAjaX2fUYPFSMv5FP7sWfVov0QbA-T5hZ8,21113
148
+ nshtrainer/trainer/trainer.py,sha256=BKRicDlLI7KstzuP0SmzJzp0U4GK5lhZcKHS1IuL5sA,21197
148
149
  nshtrainer/util/_environment_info.py,sha256=MT8mBe6ZolRfKiwU-les1P-lPNPqXpHQcfADrh_A3uY,24629
149
150
  nshtrainer/util/bf16.py,sha256=9QhHZCkYSfYpIcxwAMoXyuh2yTSHBzT-EdLQB297jEs,762
150
151
  nshtrainer/util/config/__init__.py,sha256=Z39JJufSb61Lhn2GfVcv3eFW_eorOrN9-9llDWlnZZM,272
@@ -156,6 +157,6 @@ nshtrainer/util/seed.py,sha256=diMV8iwBKN7Xxt5pELmui-gyqyT80_CZzomrWhNss0k,316
156
157
  nshtrainer/util/slurm.py,sha256=HflkP5iI_r4UHMyPjw9R4dD5AHsJUpcfJw5PLvGYBRM,1603
157
158
  nshtrainer/util/typed.py,sha256=Xt5fUU6zwLKSTLUdenovnKK0N8qUq89Kddz2_XeykVQ,164
158
159
  nshtrainer/util/typing_utils.py,sha256=MjY-CUX9R5Tzat-BlFnQjwl1PQ_W2yZQoXhkYHlJ_VA,442
159
- nshtrainer-1.1.0.dist-info/METADATA,sha256=9prl1vhkmagIgO0WmzfDEknzv4S0a0qDAh-Bj3LxFjI,960
160
- nshtrainer-1.1.0.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
161
- nshtrainer-1.1.0.dist-info/RECORD,,
160
+ nshtrainer-1.1.2.dist-info/METADATA,sha256=q4vkMHw5QYAM7fpevBqt9by1Ons3DPT_vdsbC1AswQg,960
161
+ nshtrainer-1.1.2.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
162
+ nshtrainer-1.1.2.dist-info/RECORD,,