nshtrainer 0.11.13__tar.gz → 0.12.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/PKG-INFO +1 -1
  2. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/pyproject.toml +1 -1
  3. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/__init__.py +3 -1
  4. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +1 -1
  5. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +1 -1
  6. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/early_stopping.py +68 -0
  7. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/model/config.py +1 -67
  8. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/README.md +0 -0
  9. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/__init__.py +0 -0
  10. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/_checkpoint/loader.py +0 -0
  11. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  12. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/_checkpoint/saver.py +0 -0
  13. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/_experimental/__init__.py +0 -0
  14. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  15. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  16. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  17. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  18. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/actsave.py +0 -0
  19. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/base.py +0 -0
  20. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  21. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  22. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  23. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/ema.py +0 -0
  24. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  25. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  26. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/interval.py +0 -0
  27. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  28. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  29. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/print_table.py +0 -0
  30. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  31. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/timer.py +0 -0
  32. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  33. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/data/__init__.py +0 -0
  34. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  35. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/data/transform.py +0 -0
  36. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/__init__.py +0 -0
  37. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/_experimental.py +0 -0
  38. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/actsave.py +0 -0
  39. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/callbacks.py +0 -0
  40. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/config.py +0 -0
  41. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/data.py +0 -0
  42. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/log.py +0 -0
  43. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  44. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/model.py +0 -0
  45. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/nn.py +0 -0
  46. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/optimizer.py +0 -0
  47. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/runner.py +0 -0
  48. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/snapshot.py +0 -0
  49. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/snoop.py +0 -0
  50. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/trainer.py +0 -0
  51. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/typecheck.py +0 -0
  52. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/ll/util.py +0 -0
  53. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  54. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  55. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  56. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  57. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/metrics/__init__.py +0 -0
  58. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/metrics/_config.py +0 -0
  59. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/model/__init__.py +0 -0
  60. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/model/base.py +0 -0
  61. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/model/modules/callback.py +0 -0
  62. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/model/modules/debug.py +0 -0
  63. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/model/modules/distributed.py +0 -0
  64. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/model/modules/logger.py +0 -0
  65. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/model/modules/profiler.py +0 -0
  66. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  67. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  68. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/nn/__init__.py +0 -0
  69. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/nn/mlp.py +0 -0
  70. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/nn/module_dict.py +0 -0
  71. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/nn/module_list.py +0 -0
  72. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
  73. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/optimizer.py +0 -0
  74. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/runner.py +0 -0
  75. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/scripts/find_packages.py +0 -0
  76. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/trainer/__init__.py +0 -0
  77. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  78. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  79. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/trainer/signal_connector.py +0 -0
  80. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/trainer/trainer.py +0 -0
  81. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/util/_environment_info.py +0 -0
  82. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/util/_useful_types.py +0 -0
  83. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/util/environment.py +0 -0
  84. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/util/seed.py +0 -0
  85. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/util/slurm.py +0 -0
  86. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/util/typed.py +0 -0
  87. {nshtrainer-0.11.13 → nshtrainer-0.12.1}/src/nshtrainer/util/typing_utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.11.13
+Version: 0.12.1
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.11.13"
+version = "0.12.1"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
src/nshtrainer/callbacks/__init__.py
@@ -13,6 +13,7 @@ from .checkpoint import (
     OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
 )
 from .early_stopping import EarlyStopping as EarlyStopping
+from .early_stopping import EarlyStoppingConfig as EarlyStoppingConfig
 from .ema import EMA as EMA
 from .ema import EMAConfig as EMAConfig
 from .finite_checks import FiniteChecksCallback as FiniteChecksCallback
@@ -34,7 +35,8 @@ from .wandb_watch import WandbWatchCallback as WandbWatchCallback
 from .wandb_watch import WandbWatchConfig as WandbWatchConfig

 CallbackConfig = Annotated[
-    ThroughputMonitorConfig
+    EarlyStoppingConfig
+    | ThroughputMonitorConfig
     | EpochTimerConfig
     | PrintTableMetricsConfig
     | FiniteChecksConfig
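With this change, EarlyStoppingConfig becomes the first member of the CallbackConfig union. Assuming the Annotated metadata (not shown in this hunk) attaches a pydantic-style discriminator on the name field, which is consistent with the name: Literal["early_stopping"] tag the new class declares in early_stopping.py further down, a serialized callback config tagged "early_stopping" would now deserialize into the new class. A minimal, hypothetical sketch of that mechanism, using stripped-down stand-ins rather than the real nshtrainer classes:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter

# Hypothetical stand-ins: only the `name` tag and one field each are modeled.
class EarlyStoppingConfig(BaseModel):
    name: Literal["early_stopping"] = "early_stopping"
    patience: int

class EpochTimerConfig(BaseModel):
    name: Literal["epoch_timer"] = "epoch_timer"

# A discriminated union: the `name` tag selects the branch during validation.
CallbackConfig = Annotated[
    Union[EarlyStoppingConfig, EpochTimerConfig],
    Field(discriminator="name"),
]

adapter = TypeAdapter(CallbackConfig)
print(adapter.validate_python({"name": "early_stopping", "patience": 5}))
# EarlyStoppingConfig(name='early_stopping', patience=5)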
src/nshtrainer/callbacks/checkpoint/best_checkpoint.py
@@ -55,7 +55,7 @@ class BestCheckpoint(CheckpointBase[BestCheckpointCallbackConfig]):

     @override
     def default_filename(self):
-        return f"epoch{{epoch:03d}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"
+        return f"epoch{{epoch}}-step{{step}}-{self._metric_name_normalized}{{{self.metric.validation_monitor}}}"

     @override
     def topk_sort_key(self, metadata: CheckpointMetadata):
src/nshtrainer/callbacks/checkpoint/last_checkpoint.py
@@ -28,7 +28,7 @@ class LastCheckpoint(CheckpointBase[LastCheckpointCallbackConfig]):

     @override
     def default_filename(self):
-        return "epoch{epoch:03d}-step{step:07d}"
+        return "epoch{epoch}-step{step}"

     @override
     def topk_sort_key(self, metadata: CheckpointMetadata):
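Taken together, the two checkpoint changes above drop the zero-padding from the default filenames, and the BestCheckpoint default additionally gains a step{step} component. A quick illustration of how the old and new LastCheckpoint templates render, using plain str.format as a stand-in for the substitution the checkpoint callbacks perform on these templates:

# Old vs. new LastCheckpoint default filename, rendered for epoch=7, step=1234.
old = "epoch{epoch:03d}-step{step:07d}".format(epoch=7, step=1234)
new = "epoch{epoch}-step{step}".format(epoch=7, step=1234)
print(old)  # epoch007-step0001234
print(new)  # epoch7-step1234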
src/nshtrainer/callbacks/early_stopping.py
@@ -1,5 +1,6 @@
 import logging
 import math
+from typing import Literal

 from lightning.fabric.utilities.rank_zero import _get_rank
 from lightning.pytorch import Trainer
@@ -7,9 +8,76 @@ from lightning.pytorch.callbacks import EarlyStopping as _EarlyStopping
 from lightning.pytorch.utilities.rank_zero import rank_prefixed_message
 from typing_extensions import override

+from .base import CallbackConfigBase
+
 log = logging.getLogger(__name__)


+class EarlyStoppingConfig(CallbackConfigBase):
+    name: Literal["early_stopping"] = "early_stopping"
+
+    monitor: str | None = None
+    """
+    The metric to monitor for early stopping.
+    If None, the primary metric will be used.
+    """
+
+    mode: Literal["min", "max"] | None = None
+    """
+    The mode for the metric to monitor for early stopping.
+    If None, the primary metric mode will be used.
+    """
+
+    patience: int
+    """
+    Number of epochs with no improvement after which training will be stopped.
+    """
+
+    min_delta: float = 1.0e-8
+    """
+    Minimum change in the monitored quantity to qualify as an improvement.
+    """
+
+    min_lr: float | None = None
+    """
+    Minimum learning rate. If the learning rate of the model is less than this value,
+    the training will be stopped.
+    """
+
+    strict: bool = True
+    """
+    Whether to enforce that the monitored quantity must improve by at least `min_delta`
+    to qualify as an improvement.
+    """
+
+    @override
+    def create_callbacks(self, root_config):
+        monitor = self.monitor
+        mode = self.mode
+        if monitor is None:
+            assert mode is None, "If `monitor` is not provided, `mode` must be None."
+
+            primary_metric = root_config.primary_metric
+            if primary_metric is None:
+                raise ValueError(
+                    "No primary metric is set, so `monitor` must be provided in `early_stopping`."
+                )
+            monitor = primary_metric.validation_monitor
+            mode = primary_metric.mode
+
+        if mode is None:
+            mode = "min"
+
+        yield EarlyStopping(
+            monitor=monitor,
+            mode=mode,
+            patience=self.patience,
+            min_delta=self.min_delta,
+            min_lr=self.min_lr,
+            strict=self.strict,
+        )
+
+
 class EarlyStopping(_EarlyStopping):
     def __init__(
         self,
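For orientation, a minimal sketch of exercising the relocated config on its own. It assumes CallbackConfigBase supports keyword construction like a standard pydantic model; the root_config stand-in is hypothetical, and with an explicit monitor the method never consults primary_metric:

from types import SimpleNamespace

from nshtrainer.callbacks import EarlyStoppingConfig

config = EarlyStoppingConfig(monitor="val/loss", mode="min", patience=10)

# create_callbacks is now a generator, so collect its output into a list.
root_config = SimpleNamespace(primary_metric=None)  # hypothetical stand-in
(callback,) = list(config.create_callbacks(root_config))
print(type(callback).__name__)  # EarlyStopping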
src/nshtrainer/model/config.py
@@ -38,6 +38,7 @@ from .._checkpoint.loader import CheckpointLoadingConfig
 from ..callbacks import (
     BestCheckpointCallbackConfig,
     CallbackConfig,
+    EarlyStoppingConfig,
     LastCheckpointCallbackConfig,
     OnExceptionCheckpointCallbackConfig,
     WandbWatchConfig,
@@ -1043,73 +1044,6 @@ class LightningTrainerKwargs(TypedDict, total=False):
     """


-class EarlyStoppingConfig(CallbackConfigBase):
-    monitor: str | None = None
-    """
-    The metric to monitor for early stopping.
-    If None, the primary metric will be used.
-    """
-
-    mode: Literal["min", "max"] | None = None
-    """
-    The mode for the metric to monitor for early stopping.
-    If None, the primary metric mode will be used.
-    """
-
-    patience: int
-    """
-    Number of epochs with no improvement after which training will be stopped.
-    """
-
-    min_delta: float = 1.0e-8
-    """
-    Minimum change in the monitored quantity to qualify as an improvement.
-    """
-
-    min_lr: float | None = None
-    """
-    Minimum learning rate. If the learning rate of the model is less than this value,
-    the training will be stopped.
-    """
-
-    strict: bool = True
-    """
-    Whether to enforce that the monitored quantity must improve by at least `min_delta`
-    to qualify as an improvement.
-    """
-
-    @override
-    def create_callbacks(self, root_config: "BaseConfig"):
-        from ..callbacks.early_stopping import EarlyStopping
-
-        monitor = self.monitor
-        mode = self.mode
-        if monitor is None:
-            assert mode is None, "If `monitor` is not provided, `mode` must be None."
-
-            primary_metric = root_config.primary_metric
-            if primary_metric is None:
-                raise ValueError(
-                    "No primary metric is set, so `monitor` must be provided in `early_stopping`."
-                )
-            monitor = primary_metric.validation_monitor
-            mode = primary_metric.mode
-
-        if mode is None:
-            mode = "min"
-
-        return [
-            EarlyStopping(
-                monitor=monitor,
-                mode=mode,
-                patience=self.patience,
-                min_delta=self.min_delta,
-                min_lr=self.min_lr,
-                strict=self.strict,
-            )
-        ]
-
-
 class SanityCheckingConfig(C.Config):
     reduce_lr_on_plateau: Literal["disable", "error", "warn"] = "error"
     """