nshtrainer 0.10.0__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,7 +10,7 @@ from lightning.pytorch.trainer.states import TrainerFn
10
10
  from typing_extensions import assert_never
11
11
 
12
12
  from ..metrics._config import MetricConfig
13
- from ._checkpoint_metadata import METADATA_PATH_SUFFIX, CheckpointMetadata
13
+ from .metadata import METADATA_PATH_SUFFIX, CheckpointMetadata
14
14
 
15
15
  if TYPE_CHECKING:
16
16
  from ..model.config import BaseConfig
@@ -12,7 +12,7 @@ from ..model._environment import EnvironmentConfig
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  from ..model import BaseConfig, LightningModuleBase
15
- from .trainer import Trainer
15
+ from ..trainer.trainer import Trainer
16
16
 
17
17
  log = logging.getLogger(__name__)
18
18
 
@@ -12,7 +12,7 @@ log = logging.getLogger(__name__)
12
12
 
13
13
 
14
14
  class LatestEpochCheckpointCallbackConfig(CallbackConfigBase):
15
- kind: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
15
+ name: Literal["latest_epoch_checkpoint"] = "latest_epoch_checkpoint"
16
16
 
17
17
  dirpath: str | Path | None = None
18
18
  """Directory path to save the checkpoint file."""
@@ -43,7 +43,7 @@ def _convert_string(input_string: str):
43
43
  class ModelCheckpointCallbackConfig(CallbackConfigBase):
44
44
  """Arguments for the ModelCheckpoint callback."""
45
45
 
46
- kind: Literal["model_checkpoint"] = "model_checkpoint"
46
+ name: Literal["model_checkpoint"] = "model_checkpoint"
47
47
 
48
48
  dirpath: str | Path | None = None
49
49
  """
@@ -43,7 +43,7 @@ def _monkey_patch_disable_barrier(trainer: LightningTrainer):
43
43
 
44
44
 
45
45
  class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):
46
- kind: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"
46
+ name: Literal["on_exception_checkpoint"] = "on_exception_checkpoint"
47
47
 
48
48
  dirpath: str | Path | None = None
49
49
  """Directory path to save the checkpoint file."""
@@ -34,6 +34,7 @@ from lightning.pytorch.strategies.strategy import Strategy
34
34
  from pydantic import DirectoryPath
35
35
  from typing_extensions import Self, TypedDict, TypeVar, override
36
36
 
37
+ from .._checkpoint.loader import CheckpointLoadingConfig
37
38
  from ..callbacks import (
38
39
  CallbackConfig,
39
40
  LatestEpochCheckpointCallbackConfig,
@@ -43,7 +44,6 @@ from ..callbacks import (
43
44
  )
44
45
  from ..callbacks.base import CallbackConfigBase
45
46
  from ..metrics import MetricConfig
46
- from ..trainer._checkpoint_resolver import CheckpointLoadingConfig
47
47
  from ._environment import EnvironmentConfig
48
48
 
49
49
  log = getLogger(__name__)
@@ -71,7 +71,7 @@ class BaseProfilerConfig(C.Config, ABC):
71
71
 
72
72
 
73
73
  class SimpleProfilerConfig(BaseProfilerConfig):
74
- kind: Literal["simple"] = "simple"
74
+ name: Literal["simple"] = "simple"
75
75
 
76
76
  extended: bool = True
77
77
  """
@@ -99,7 +99,7 @@ class SimpleProfilerConfig(BaseProfilerConfig):
99
99
 
100
100
 
101
101
  class AdvancedProfilerConfig(BaseProfilerConfig):
102
- kind: Literal["advanced"] = "advanced"
102
+ name: Literal["advanced"] = "advanced"
103
103
 
104
104
  line_count_restriction: float = 1.0
105
105
  """
@@ -128,7 +128,7 @@ class AdvancedProfilerConfig(BaseProfilerConfig):
128
128
 
129
129
 
130
130
  class PyTorchProfilerConfig(BaseProfilerConfig):
131
- kind: Literal["pytorch"] = "pytorch"
131
+ name: Literal["pytorch"] = "pytorch"
132
132
 
133
133
  group_by_input_shapes: bool = False
134
134
  """Include operator input shapes and group calls by shape."""
@@ -204,7 +204,7 @@ class PyTorchProfilerConfig(BaseProfilerConfig):
204
204
 
205
205
  ProfilerConfig: TypeAlias = Annotated[
206
206
  SimpleProfilerConfig | AdvancedProfilerConfig | PyTorchProfilerConfig,
207
- C.Field(discriminator="kind"),
207
+ C.Field(discriminator="name"),
208
208
  ]
209
209
 
210
210
 
@@ -260,7 +260,7 @@ def _wandb_available():
260
260
 
261
261
 
262
262
  class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
263
- kind: Literal["wandb"] = "wandb"
263
+ name: Literal["wandb"] = "wandb"
264
264
 
265
265
  enabled: bool = C.Field(default_factory=lambda: _wandb_available())
266
266
  """Enable WandB logging."""
@@ -319,7 +319,7 @@ class WandbLoggerConfig(CallbackConfigBase, BaseLoggerConfig):
319
319
 
320
320
 
321
321
  class CSVLoggerConfig(BaseLoggerConfig):
322
- kind: Literal["csv"] = "csv"
322
+ name: Literal["csv"] = "csv"
323
323
 
324
324
  enabled: bool = True
325
325
  """Enable CSV logging."""
@@ -373,7 +373,7 @@ def _tensorboard_available():
373
373
 
374
374
 
375
375
  class TensorboardLoggerConfig(BaseLoggerConfig):
376
- kind: Literal["tensorboard"] = "tensorboard"
376
+ name: Literal["tensorboard"] = "tensorboard"
377
377
 
378
378
  enabled: bool = C.Field(default_factory=lambda: _tensorboard_available())
379
379
  """Enable TensorBoard logging."""
@@ -419,7 +419,7 @@ class TensorboardLoggerConfig(BaseLoggerConfig):
419
419
 
420
420
  LoggerConfig: TypeAlias = Annotated[
421
421
  WandbLoggerConfig | CSVLoggerConfig | TensorboardLoggerConfig,
422
- C.Field(discriminator="kind"),
422
+ C.Field(discriminator="name"),
423
423
  ]
424
424
 
425
425
 
@@ -717,9 +717,9 @@ class DirectoryConfig(C.Config):
717
717
  if (log_dir := logger.log_dir) is not None:
718
718
  return log_dir
719
719
 
720
- # Save to nshtrainer/{id}/log/{logger kind}
720
+ # Save to nshtrainer/{id}/log/{logger name}
721
721
  log_dir = self.resolve_subdirectory(run_id, "log")
722
- log_dir = log_dir / logger.kind
722
+ log_dir = log_dir / logger.name
723
723
  log_dir.mkdir(exist_ok=True)
724
724
 
725
725
  return log_dir
@@ -738,7 +738,7 @@ CheckpointCallbackConfig: TypeAlias = Annotated[
738
738
  ModelCheckpointCallbackConfig
739
739
  | LatestEpochCheckpointCallbackConfig
740
740
  | OnExceptionCheckpointCallbackConfig,
741
- C.Field(discriminator="kind"),
741
+ C.Field(discriminator="name"),
742
742
  ]
743
743
 
744
744
 
@@ -8,7 +8,7 @@ from lightning.pytorch.trainer.connectors.checkpoint_connector import (
8
8
  from lightning.pytorch.trainer.states import TrainerFn
9
9
  from typing_extensions import override
10
10
 
11
- from ._checkpoint_resolver import CheckpointLoadingConfig, _resolve_checkpoint
11
+ from .._checkpoint.loader import CheckpointLoadingConfig, _resolve_checkpoint
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  from ..model.config import BaseConfig
@@ -17,6 +17,7 @@ from lightning.pytorch.trainer.states import TrainerFn
17
17
  from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT
18
18
  from typing_extensions import Unpack, assert_never, override
19
19
 
20
+ from .._checkpoint.metadata import _write_checkpoint_metadata
20
21
  from ..callbacks.base import resolve_all_callbacks
21
22
  from ..model.config import (
22
23
  AcceleratorConfigProtocol,
@@ -25,7 +26,6 @@ from ..model.config import (
25
26
  LightningTrainerKwargs,
26
27
  StrategyConfigProtocol,
27
28
  )
28
- from ._checkpoint_metadata import _write_checkpoint_metadata
29
29
  from ._runtime_callback import RuntimeTrackerCallback, Stage
30
30
  from .signal_connector import _SignalConnector
31
31
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.10.0
3
+ Version: 0.10.2
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,4 +1,6 @@
1
1
  nshtrainer/__init__.py,sha256=39loiLLXbaGiozEsAn8mPHopxaPsek8JsgR9DD2gxtY,583
2
+ nshtrainer/_checkpoint/loader.py,sha256=48flPr1XgQHOgIPaCrRqOEvRuG0SZuV3cQ1vgHLqFqI,11025
3
+ nshtrainer/_checkpoint/metadata.py,sha256=C7je_soYyEbZjiq7p2_pSVFkgcXnz2J2H5sMy8oskx0,3051
2
4
  nshtrainer/_experimental/__init__.py,sha256=2tQIcrWT8U8no_AeBTYnozaTmxN40kuAJdGQ4b-PoWM,120
3
5
  nshtrainer/_experimental/flops/__init__.py,sha256=edo9Ez3LlrnxkNRX9W6YBhPkRPKYGLpkpnl5gx7sEX8,1550
4
6
  nshtrainer/_experimental/flops/flop_counter.py,sha256=-sL0Fy6poXa__hyzUMdZScjPULp4coQELQpPU6p6dXU,25736
@@ -12,11 +14,11 @@ nshtrainer/callbacks/ema.py,sha256=8-WHmKFP3VfnzMviJaIFmVD9xHPqIPmq9NRF5xdu3c8,1
12
14
  nshtrainer/callbacks/finite_checks.py,sha256=AO5fa51uANAjAkeJfTquOjK6W_4RSU5Kky3f5jmAPlQ,2084
13
15
  nshtrainer/callbacks/gradient_skipping.py,sha256=fSJpjgHbztFKz7w3qFuCHZpmbEt9BCLAy-sU0B4xJQI,3474
14
16
  nshtrainer/callbacks/interval.py,sha256=smz5Zl8cN6X6yHKVsMRS2e3SEkzRCP3LvwE1ONvLfaw,8080
15
- nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=p0zeDK3PLWWl485e9o08ywEEARCfuZ5it47tNCtR4ec,2838
17
+ nshtrainer/callbacks/latest_epoch_checkpoint.py,sha256=UnwgGIc2reD7cTnUeIlDHo1LeAkgLEZFNvy2NGvUfRQ,2838
16
18
  nshtrainer/callbacks/log_epoch.py,sha256=fTa_K_Y8A7g09630cG4YkDE6AzSMPkjb9bpPm4gtqos,1120
17
- nshtrainer/callbacks/model_checkpoint.py,sha256=4zYycpXHGRyL4svWLP6GmG3WJs5m3B5PRCOzXC3m_qg,5955
19
+ nshtrainer/callbacks/model_checkpoint.py,sha256=N0raLsHlCVSbO3QU5eNFUXUDqxxW3C73oQwceMnFE_k,5955
18
20
  nshtrainer/callbacks/norm_logging.py,sha256=EWyrfkp8iHjQi9iAAXHxb0xStw2RwkdpKG2_gLarQRA,6281
19
- nshtrainer/callbacks/on_exception_checkpoint.py,sha256=zna_QF_x4HwD7Es5XxrHLDED43NU1GpcDNoL139HEOs,3355
21
+ nshtrainer/callbacks/on_exception_checkpoint.py,sha256=x42BYZ2ejf2rhqPLCmT5nyWKhA9qBEosiV8ZNhhZ6lI,3355
20
22
  nshtrainer/callbacks/print_table.py,sha256=_FdAHhqylWGk4Z0c2FrLFeiMA4jhfA_beZRK_BHpzmE,2837
21
23
  nshtrainer/callbacks/throughput_monitor.py,sha256=4EF3b79HdHiRgBGIFDyD4O1oywb5h1tV8nml7NuuDjU,1845
22
24
  nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
@@ -50,7 +52,7 @@ nshtrainer/metrics/_config.py,sha256=hWWS4IXENRyH3RmJ7z1Wx1n3Lt1sNMlGOrcU6PW15o0
50
52
  nshtrainer/model/__init__.py,sha256=TbexTxiE20WHYg5q3L88Hysk4LlHeKk_isv33aSBREA,1918
51
53
  nshtrainer/model/_environment.py,sha256=s3JFnigbssFRJTwH33K7DcAYVhLOFCC1OZgFNXJgjuw,22317
52
54
  nshtrainer/model/base.py,sha256=Bmw-t70TydDbE9P0ee-lTibGoUhrCx5Qke-upa7FGVM,17512
53
- nshtrainer/model/config.py,sha256=f8gbTaIi02U8EyooC1vv2ElZfXPgMIAVtU0n-LnkNE4,53187
55
+ nshtrainer/model/config.py,sha256=OsVba02cmEYVf6V-A6ljV7VMAW5XZO6GWNRk8ktUw2o,53177
54
56
  nshtrainer/model/modules/callback.py,sha256=JF59U9-CjJsAIspEhTJbVaGN0wGctZG7UquE3IS7R8A,6408
55
57
  nshtrainer/model/modules/debug.py,sha256=DTVty8cKnzj1GCULRyGx_sWTTsq9NLi30dzqjRTnuCU,1127
56
58
  nshtrainer/model/modules/distributed.py,sha256=ABpR9d-3uBS_fivfy_WYW-dExW6vp5BPaoPQnOudHng,1725
@@ -68,17 +70,15 @@ nshtrainer/runner.py,sha256=6qfE5FBONzD79kVHuWYKEvK0J_Qi5dMBbHQhRMmnIhE,3649
68
70
  nshtrainer/scripts/check_env.py,sha256=IMl6dSqsLYppI0XuCsVq8lK4bYqXwY9KHJkzsShz4Kg,806
69
71
  nshtrainer/scripts/find_packages.py,sha256=FbdlfmAefttFSMfaT0A46a-oHLP_ioaQKihwBfBeWeA,1467
70
72
  nshtrainer/trainer/__init__.py,sha256=P2rmr8oBVTHk-HJHYPcUwWqDEArMbPR4_rPpATbWK3E,40
71
- nshtrainer/trainer/_checkpoint_metadata.py,sha256=dj3g0rUZLWfohIRFAhhLqB4qh1fJsquQ5-EZ0Zbl5ZE,3042
72
- nshtrainer/trainer/_checkpoint_resolver.py,sha256=kfIccBLWAMwn-Bw1pbj3XTXXaCdO_taUEUp3RdwFuLY,11037
73
73
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
74
- nshtrainer/trainer/checkpoint_connector.py,sha256=9DrZliK95BIZfwVFxL06Uf7DbfbQ5UAWd0xckH-LU6U,2125
74
+ nshtrainer/trainer/checkpoint_connector.py,sha256=xoqI2dcPnlNFPPLVIU6dBOvRPC9PtfX5qu__xV1lx0Y,2124
75
75
  nshtrainer/trainer/signal_connector.py,sha256=llwc8pdKAWxREFpjdi14Bpy8rGVMEJsmJx_s2p4gI8E,10689
76
- nshtrainer/trainer/trainer.py,sha256=qHwerfdQUCU21IkWf50d_qZAIHb2d8qOLfqTszBpzks,16784
76
+ nshtrainer/trainer/trainer.py,sha256=n3T9Iz3eaDostxEdjapWImAsVMxyU9WBdhlPl0THX-g,16785
77
77
  nshtrainer/util/environment.py,sha256=_SEtiQ_s5bL5pllUlf96AOUv15kNvCPvocVC13S7mIk,4166
78
78
  nshtrainer/util/seed.py,sha256=HEXgVs-wldByahOysKwq7506OHxdYTEgmP-tDQVAEkQ,287
79
79
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
80
80
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
81
81
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
82
- nshtrainer-0.10.0.dist-info/METADATA,sha256=GslAMAaEXDbMxDd4ijoqjQKYBjb0iAnEGkZ3pAF_sOQ,695
83
- nshtrainer-0.10.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
84
- nshtrainer-0.10.0.dist-info/RECORD,,
82
+ nshtrainer-0.10.2.dist-info/METADATA,sha256=bLP7xa9qV9BqgqGZr2RayfaL2eFB9eBSTHVlMpslm5s,695
83
+ nshtrainer-0.10.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
84
+ nshtrainer-0.10.2.dist-info/RECORD,,