nshtrainer 0.11.1__tar.gz → 0.11.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/PKG-INFO +1 -1
  2. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/pyproject.toml +1 -1
  3. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/_checkpoint/loader.py +62 -0
  4. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/_checkpoint/metadata.py +16 -29
  5. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/__init__.py +14 -12
  6. nshtrainer-0.11.3/src/nshtrainer/callbacks/checkpoint/__init__.py +16 -0
  7. nshtrainer-0.11.3/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +155 -0
  8. {nshtrainer-0.11.1/src/nshtrainer/callbacks → nshtrainer-0.11.3/src/nshtrainer/callbacks/checkpoint}/latest_epoch_checkpoint.py +12 -9
  9. {nshtrainer-0.11.1/src/nshtrainer/callbacks → nshtrainer-0.11.3/src/nshtrainer/callbacks/checkpoint}/model_checkpoint.py +19 -8
  10. {nshtrainer-0.11.1/src/nshtrainer/callbacks → nshtrainer-0.11.3/src/nshtrainer/callbacks/checkpoint}/on_exception_checkpoint.py +1 -3
  11. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/metrics/_config.py +5 -0
  12. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/model/config.py +4 -1
  13. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/trainer/checkpoint_connector.py +2 -2
  14. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/trainer/trainer.py +4 -0
  15. nshtrainer-0.11.3/src/nshtrainer/util/_useful_types.py +307 -0
  16. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/README.md +0 -0
  17. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/__init__.py +0 -0
  18. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/_checkpoint/saver.py +0 -0
  19. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/_experimental/__init__.py +0 -0
  20. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  21. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  22. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  23. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  24. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/actsave.py +0 -0
  25. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/base.py +0 -0
  26. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  27. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/ema.py +0 -0
  28. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  29. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  30. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/interval.py +0 -0
  31. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  32. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  33. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/print_table.py +0 -0
  34. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  35. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/timer.py +0 -0
  36. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  37. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/data/__init__.py +0 -0
  38. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  39. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/data/transform.py +0 -0
  40. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/__init__.py +0 -0
  41. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/_experimental.py +0 -0
  42. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/actsave.py +0 -0
  43. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/callbacks.py +0 -0
  44. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/config.py +0 -0
  45. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/data.py +0 -0
  46. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/log.py +0 -0
  47. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  48. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/model.py +0 -0
  49. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/nn.py +0 -0
  50. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/optimizer.py +0 -0
  51. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/runner.py +0 -0
  52. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/snapshot.py +0 -0
  53. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/snoop.py +0 -0
  54. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/trainer.py +0 -0
  55. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/typecheck.py +0 -0
  56. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/ll/util.py +0 -0
  57. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  58. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  59. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  60. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  61. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/metrics/__init__.py +0 -0
  62. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/model/__init__.py +0 -0
  63. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/model/base.py +0 -0
  64. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/model/modules/callback.py +0 -0
  65. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/model/modules/debug.py +0 -0
  66. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/model/modules/distributed.py +0 -0
  67. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/model/modules/logger.py +0 -0
  68. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/model/modules/profiler.py +0 -0
  69. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  70. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  71. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/nn/__init__.py +0 -0
  72. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/nn/mlp.py +0 -0
  73. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/nn/module_dict.py +0 -0
  74. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/nn/module_list.py +0 -0
  75. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/nn/nonlinearity.py +0 -0
  76. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/optimizer.py +0 -0
  77. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/runner.py +0 -0
  78. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/scripts/find_packages.py +0 -0
  79. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/trainer/__init__.py +0 -0
  80. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  81. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/trainer/signal_connector.py +0 -0
  82. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/util/_environment_info.py +0 -0
  83. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/util/environment.py +0 -0
  84. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/util/seed.py +0 -0
  85. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/util/slurm.py +0 -0
  86. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/util/typed.py +0 -0
  87. {nshtrainer-0.11.1 → nshtrainer-0.11.3}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.11.1
+Version: 0.11.3
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.11.1"
+version = "0.11.3"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
@@ -133,6 +133,68 @@ class CheckpointLoadingConfig(C.Config):
         ckpt: Literal["best", "last"] | str | Path | None,
         trainer_mode: TrainerFn,
     ):
+        """
+        Automatically create a CheckpointLoadingConfig based on the provided checkpoint option and trainer mode.
+
+        This method provides a convenient way to generate a checkpoint loading configuration
+        tailored to different training and evaluation scenarios.
+
+        Parameters:
+        -----------
+        ckpt : Literal["best", "last"] | str | Path | None
+            Specifies the checkpoint loading preference:
+            - "best": Use the best checkpoint based on the primary metric.
+            - "last": Use the most recent checkpoint.
+            - str or Path: Path to a specific checkpoint file.
+            - None: Defaults to "last" for training, raises an error for evaluation.
+
+        trainer_mode : TrainerFn
+            The mode in which the trainer is operating. This affects how the configuration is created.
+            - TrainerFn.FITTING: Used for training scenarios.
+            - TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING: Used for evaluation scenarios.
+
+        Returns:
+        --------
+        CheckpointLoadingConfig
+            A configuration object for checkpoint loading based on the given parameters.
+
+        Behavior:
+        ---------
+        1. For training (TrainerFn.FITTING):
+           - Includes HPC pre-emption checkpoints.
+           - If ckpt is None, defaults to "last".
+           - For "best" or "last", creates a single-strategy configuration that loads the best or last checkpoint.
+           - For a specific path, creates a two-strategy configuration:
+             a) Tries to load the checkpoint as the last checkpoint.
+             b) Falls back to loading it as a user-provided path.
+
+        2. For evaluation (VALIDATING, TESTING, PREDICTING):
+           - Does not include HPC pre-emption checkpoints.
+           - Requires ckpt to be specified (raises ValueError if None).
+           - Creates a single-strategy configuration based on the ckpt value.
+
+        Raises:
+        -------
+        ValueError
+            If ckpt is None during evaluation modes.
+
+        Examples:
+        ---------
+        # Training mode, use last checkpoint
+        config = CheckpointLoadingConfig.auto("last", TrainerFn.FITTING)
+
+        # Evaluation mode, use best checkpoint
+        config = CheckpointLoadingConfig.auto("best", TrainerFn.TESTING)
+
+        # Training mode, use specific checkpoint
+        config = CheckpointLoadingConfig.auto("/path/to/checkpoint.ckpt", TrainerFn.FITTING)
+
+        Notes:
+        ------
+        - The method internally calls _auto_train or _auto_eval based on the trainer_mode.
+        - The resulting configuration always includes strategies as a sequence, even if there's only one strategy.
+        """
+        # Implementation remains the same...
         match trainer_mode:
             case TrainerFn.FITTING:
                 return cls._auto_train(ckpt)
@@ -43,6 +43,16 @@ class CheckpointMetadata(C.Config):
     def from_file(cls, path: Path):
         return cls.model_validate_json(path.read_text())

+    @classmethod
+    def from_ckpt_path(cls, checkpoint_path: Path):
+        if not (
+            metadata_path := checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)
+        ).exists():
+            raise FileNotFoundError(
+                f"Metadata file not found for checkpoint: {checkpoint_path}"
+            )
+        return cls.from_file(metadata_path)
+

 def _generate_checkpoint_metadata(
     config: "BaseConfig", trainer: "Trainer", checkpoint_path: Path
@@ -136,36 +146,13 @@ def _link_checkpoint_metadata(checkpoint_path: Path, linked_checkpoint_path: Pat
         log.debug(f"Linked {path} to {linked_path}")


-def _checkpoint_sort_key_fn(key: Callable[[CheckpointMetadata, Path], Any]):
-    def sort_key_fn(checkpoint_path: Path):
-        if not (p := checkpoint_path.with_suffix(METADATA_PATH_SUFFIX)).exists():
-            raise FileNotFoundError(f"Metadata file not found: {p}")
-
-        nonlocal key
-        return key(CheckpointMetadata.from_file(p), p)
-
-    return sort_key_fn
-
-
 def _sort_ckpts_by_metadata(
     checkpoint_paths: list[Path],
     key: Callable[[CheckpointMetadata, Path], Any],
-    fallback_key: Callable[[Path], Any],
+    reverse: bool = False,
 ):
-    # First, let's make sure all the metadata files exist.
-    # If not, use the fallback function to sort the checkpoints.
-    no_metadata_paths: list[Path] = []
-    for path in checkpoint_paths:
-        if (path.with_suffix(METADATA_PATH_SUFFIX)).exists():
-            continue
-
-        no_metadata_paths.append(path)
-
-    if no_metadata_paths:
-        log.warning(
-            f"Metadata file not found on {len(no_metadata_paths)} checkpoints: {no_metadata_paths}\n"
-            "Falling back to sorting by last modified time."
-        )
-        return sorted(checkpoint_paths, key=fallback_key)
-
-    return sorted(checkpoint_paths, key=_checkpoint_sort_key_fn(key))
+    return sorted(
+        [(CheckpointMetadata.from_ckpt_path(path), path) for path in checkpoint_paths],
+        key=lambda args_tuple: key(*args_tuple),
+        reverse=reverse,
+    )
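A minimal usage sketch of the reworked helper above (hypothetical checkpoint paths; _sort_ckpts_by_metadata is a private name from this diff, and every checkpoint is assumed to have a sidecar metadata file, since the old fallback path was removed):

    # Hypothetical sketch: sort checkpoints newest-first by their metadata.
    from pathlib import Path

    from nshtrainer._checkpoint.metadata import _sort_ckpts_by_metadata

    ckpt_paths = [Path("epoch01_step0100.ckpt"), Path("epoch02_step0200.ckpt")]

    # The key now receives the parsed CheckpointMetadata plus the checkpoint path,
    # and the helper returns (metadata, path) tuples instead of bare paths.
    newest_first = _sort_ckpts_by_metadata(
        ckpt_paths,
        key=lambda meta, path: (meta.epoch, meta.global_step, path.stat().st_mtime),
        reverse=True,
    )
    for _meta, path in newest_first:
        print(path)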
@@ -2,7 +2,20 @@ from typing import Annotated

 import nshconfig as C

+from . import checkpoint as checkpoint
 from .base import CallbackConfigBase as CallbackConfigBase
+from .checkpoint import BestCheckpoint as BestCheckpoint
+from .checkpoint import BestCheckpointCallbackConfig as BestCheckpointCallbackConfig
+from .checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
+from .checkpoint import (
+    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
+)
+from .checkpoint import ModelCheckpoint as ModelCheckpoint
+from .checkpoint import ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig
+from .checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
+from .checkpoint import (
+    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+)
 from .early_stopping import EarlyStopping as EarlyStopping
 from .ema import EMA as EMA
 from .ema import EMAConfig as EMAConfig
@@ -13,21 +26,9 @@ from .gradient_skipping import GradientSkippingConfig as GradientSkippingConfig
 from .interval import EpochIntervalCallback as EpochIntervalCallback
 from .interval import IntervalCallback as IntervalCallback
 from .interval import StepIntervalCallback as StepIntervalCallback
-from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
-from .latest_epoch_checkpoint import (
-    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
-)
 from .log_epoch import LogEpochCallback as LogEpochCallback
-from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
-from .model_checkpoint import (
-    ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
-)
 from .norm_logging import NormLoggingCallback as NormLoggingCallback
 from .norm_logging import NormLoggingConfig as NormLoggingConfig
-from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
-from .on_exception_checkpoint import (
-    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
-)
 from .print_table import PrintTableMetricsCallback as PrintTableMetricsCallback
 from .print_table import PrintTableMetricsConfig as PrintTableMetricsConfig
 from .throughput_monitor import ThroughputMonitorConfig as ThroughputMonitorConfig
@@ -44,6 +45,7 @@ CallbackConfig = Annotated[
     | NormLoggingConfig
     | GradientSkippingConfig
     | EMAConfig
+    | BestCheckpointCallbackConfig
    | ModelCheckpointCallbackConfig
    | LatestEpochCheckpointCallbackConfig
    | OnExceptionCheckpointCallbackConfig
@@ -0,0 +1,16 @@
+from .best_checkpoint import BestCheckpoint as BestCheckpoint
+from .best_checkpoint import (
+    BestCheckpointCallbackConfig as BestCheckpointCallbackConfig,
+)
+from .latest_epoch_checkpoint import LatestEpochCheckpoint as LatestEpochCheckpoint
+from .latest_epoch_checkpoint import (
+    LatestEpochCheckpointCallbackConfig as LatestEpochCheckpointCallbackConfig,
+)
+from .model_checkpoint import ModelCheckpoint as ModelCheckpoint
+from .model_checkpoint import (
+    ModelCheckpointCallbackConfig as ModelCheckpointCallbackConfig,
+)
+from .on_exception_checkpoint import OnExceptionCheckpoint as OnExceptionCheckpoint
+from .on_exception_checkpoint import (
+    OnExceptionCheckpointCallbackConfig as OnExceptionCheckpointCallbackConfig,
+)
@@ -0,0 +1,155 @@
+import logging
+from pathlib import Path
+from typing import Any, Literal
+
+from lightning.pytorch import LightningModule, Trainer
+from lightning.pytorch.callbacks import Checkpoint
+from typing_extensions import override
+
+from ..._checkpoint.metadata import _sort_ckpts_by_metadata
+from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+from ...metrics._config import MetricConfig
+from ..base import CallbackConfigBase
+
+log = logging.getLogger(__name__)
+
+
+class BestCheckpointCallbackConfig(CallbackConfigBase):
+    name: Literal["best_checkpoint"] = "best_checkpoint"
+
+    dirpath: str | Path | None = None
+    """Directory path to save the checkpoint file."""
+
+    filename: str = "epoch{epoch:02d}_step{step:04d}"
+    """Checkpoint filename. This must not include the extension."""
+
+    save_weights_only: bool = False
+    """Whether to save only the model's weights or the entire model object."""
+
+    metric: MetricConfig | None = None
+    """Metric to monitor, or `None` to use the default metric."""
+
+    best_symlink_filename: str | None = "best"
+    """Filename for the best symlink. If None, no symlink will be created."""
+
+    save_top_k: int | Literal["all"] = 1
+    """The number of best checkpoints to keep."""
+
+    @override
+    def create_callbacks(self, root_config):
+        dirpath = Path(
+            self.dirpath
+            or root_config.directory.resolve_subdirectory(root_config.id, "checkpoint")
+        )
+
+        # Resolve metric
+        if (metric := self.metric) is None and (
+            metric := root_config.primary_metric
+        ) is None:
+            raise ValueError(
+                "No metric provided and no primary metric found in the root config"
+            )
+
+        yield BestCheckpoint(self, metric, dirpath)
+
+    @property
+    def _save_top_k_value(self):
+        return float("inf" if self.save_top_k == "all" else self.save_top_k)
+
+
+class BestCheckpoint(Checkpoint):
+    PREFIX = "best_"
+    EXTENSION = ".ckpt"
+
+    def __init__(
+        self,
+        config: BestCheckpointCallbackConfig,
+        metric: MetricConfig,
+        dirpath: Path,
+    ):
+        super().__init__()
+        self.config = config
+        self.metric = metric
+        self.dirpath = dirpath
+
+    @override
+    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
+        self._save_best_checkpoint(trainer)
+
+    def _best_symlink_filename(self):
+        if (filename := self.config.best_symlink_filename) is None:
+            return None
+        return f"{filename}{self.EXTENSION}"
+
+    def _ckpt_path(self, trainer: Trainer):
+        filename = self.config.filename.format(
+            epoch=trainer.current_epoch, step=trainer.global_step
+        )
+        filename = f"{self.PREFIX}{filename}{self.EXTENSION}"
+        return self.dirpath / filename
+
+    def _remove_checkpoints(self, trainer: Trainer, ckpt_paths: list[Path]):
+        for ckpt_path in ckpt_paths:
+            _remove_checkpoint(trainer, ckpt_path, metadata=True, barrier=False)
+
+    def _get_metric_value(self, metrics: dict[str, Any]):
+        return metrics.get(
+            self.metric.validation_monitor,
+            float("-inf" if self.metric.mode == "max" else "inf"),
+        )
+
+    def _sorted_ckpts(self):
+        ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
+        return _sort_ckpts_by_metadata(
+            ckpt_paths,
+            key=lambda meta, _: self._get_metric_value(meta.metrics),
+            reverse=(self.metric.mode == "min"),
+        )
+
+    def _save_best_checkpoint(self, trainer: Trainer):
+        if (current := self._get_metric_value(trainer.callback_metrics)) is None:
+            log.warning(
+                f"Can't save best model, {self.metric.validation_monitor} not found in metrics"
+            )
+            return
+
+        # Get sorted checkpoints
+        sorted_ckpts = self._sorted_ckpts()
+
+        # If the current model is worse than the worst checkpoint,
+        # and we have already saved the maximum number of checkpoints,
+        # then don't save the current model.
+        if len(
+            sorted_ckpts
+        ) >= self.config._save_top_k_value and not self.metric.is_better(
+            current,
+            self._get_metric_value(sorted_ckpts[-1][0].metrics),
+        ):
+            return
+
+        # Save the current model
+        filepath = self._ckpt_path(trainer)
+        trainer.save_checkpoint(filepath, self.config.save_weights_only)
+
+        # Remove worst checkpoint if we've reached save_top_k
+        # NOTE: We add 1 to save_top_k here because we have just saved a new checkpoint
+        if len(sorted_ckpts) + 1 > self.config._save_top_k_value:
+            # Get the sorted checkpoints again because now we have added a new checkpoint.
+            # We could optimize this by adding the new checkpoint to the sorted list,
+            # and then sorting it in place, but this is simpler.
+            sorted_ckpts = self._sorted_ckpts()
+            self._remove_checkpoints(
+                trainer, [p for _, p in sorted_ckpts[self.config.save_top_k :]]
+            )
+
+        # Create symlink to best model
+        if (symlink_filename := self._best_symlink_filename()) is not None:
+            symlink_path = self.dirpath / symlink_filename
+            _link_checkpoint(
+                trainer,
+                filepath,
+                symlink_path,
+                barrier=True,
+                metadata=True,
+            )
+            log.debug(f"Created best symlink: {symlink_path}")
@@ -6,9 +6,9 @@ from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import Checkpoint
 from typing_extensions import override

-from .._checkpoint.metadata import _sort_ckpts_by_metadata
-from .._checkpoint.saver import _link_checkpoint, _remove_checkpoint
-from .base import CallbackConfigBase
+from ..._checkpoint.metadata import _sort_ckpts_by_metadata
+from ..._checkpoint.saver import _link_checkpoint, _remove_checkpoint
+from ..base import CallbackConfigBase

 log = logging.getLogger(__name__)

@@ -75,6 +75,10 @@ class LatestEpochCheckpoint(Checkpoint):
         if (latest_k := self.config.latest_k) == "all":
             return

+        # NOTE: We add 1 to the latest_k here because
+        # we're about to save a new checkpoint.
+        latest_k += 1
+
         # Get all configs, ignoring the latest symlink
         ckpt_paths = list(self.dirpath.glob(f"{self.PREFIX}*{self.EXTENSION}"))
@@ -82,16 +86,15 @@ class LatestEpochCheckpoint(Checkpoint):
         ckpt_paths = [p for p in ckpt_paths if p.name != latest_symlink_filename]

         # Sort by epoch, then step, then last modified
-        ckpt_paths = _sort_ckpts_by_metadata(
+        metadata_and_ckpt_paths = _sort_ckpts_by_metadata(
             ckpt_paths,
             key=lambda meta, p: (meta.epoch, meta.global_step, p.stat().st_mtime),
-            fallback_key=lambda p: p.stat().st_mtime,
-            # ^ Called if metadata is not found on all checkpoints
+            reverse=True,
         )

         # Remove all but the latest k checkpoints
-        ckpts_to_remove = ckpt_paths[:-latest_k]
-        self._remove_checkpoints(trainer, ckpts_to_remove)
+        ckpts_to_remove = metadata_and_ckpt_paths[latest_k:]
+        self._remove_checkpoints(trainer, [p for _, p in ckpts_to_remove])

     def _save_new_checkpoint(self, trainer: Trainer):
         # Remove old checkpoints
@@ -113,4 +116,4 @@ class LatestEpochCheckpoint(Checkpoint):
                 barrier=True,
                 metadata=True,
             )
-            log.info(f"Created latest symlink: {symlink_path}")
+            log.debug(f"Created latest symlink: {symlink_path}")
@@ -10,12 +10,13 @@ from lightning.pytorch.callbacks.model_checkpoint import (
 )
 from typing_extensions import override

-from .._checkpoint.saver import _link_checkpoint, _remove_checkpoint
-from ..metrics import MetricConfig
-from .base import CallbackConfigBase
+from ..._checkpoint.saver import _link_checkpoint
+from ..._checkpoint.saver import _remove_checkpoint as _ckpt_saver_remove_checkpoint
+from ...metrics import MetricConfig
+from ..base import CallbackConfigBase

 if TYPE_CHECKING:
-    from ..model.config import BaseConfig
+    from ...model.config import BaseConfig


 log = logging.getLogger(__name__)
@@ -74,10 +75,10 @@ class ModelCheckpointCallbackConfig(CallbackConfigBase):
    If "link", creates a symbolic link to the last checkpoint.
    """

-    save_top_k: int = 1
+    save_top_k: int | Literal["all"] = 1
    """
    Number of best models to save.
-    If -1, all models are saved.
+    If "all" or -1, all models are saved.
    If 0, no models are saved.
    """

@@ -158,6 +159,11 @@ class ModelCheckpointCallbackConfig(CallbackConfigBase):
            metric=metric,
        )

+    def _save_top_k_model_ckpt_input(self):
+        if self.save_top_k == "all":
+            return -1
+        return self.save_top_k
+

 class ModelCheckpoint(_ModelCheckpoint):
     CHECKPOINT_NAME_LAST = "best"
@@ -180,7 +186,7 @@ class ModelCheckpoint(_ModelCheckpoint):
            mode=metric.mode,
            verbose=self.config.verbose,
            save_last=self.config.save_last,
-            save_top_k=self.config.save_top_k,
+            save_top_k=self.config._save_top_k_model_ckpt_input(),
            save_weights_only=self.config.save_weights_only,
            auto_insert_metric_name=False,
            every_n_train_steps=self.config.every_n_train_steps,
@@ -202,4 +208,9 @@ class ModelCheckpoint(_ModelCheckpoint):

     @override
     def _remove_checkpoint(self, trainer: Trainer, filepath: str):
-        return _remove_checkpoint(trainer, filepath, metadata=True, barrier=False)
+        return _ckpt_saver_remove_checkpoint(
+            trainer,
+            filepath,
+            metadata=True,
+            barrier=False,
+        )
@@ -9,7 +9,7 @@ from lightning.pytorch import Trainer as LightningTrainer
 from lightning.pytorch.callbacks import OnExceptionCheckpoint as _OnExceptionCheckpoint
 from typing_extensions import override

-from .base import CallbackConfigBase
+from ..base import CallbackConfigBase

 log = logging.getLogger(__name__)

@@ -53,8 +53,6 @@ class OnExceptionCheckpointCallbackConfig(CallbackConfigBase):

     @override
     def create_callbacks(self, root_config):
-        from ..callbacks.on_exception_checkpoint import OnExceptionCheckpoint
-
         dirpath = self.dirpath or root_config.directory.resolve_subdirectory(
             root_config.id, "checkpoint"
         )
@@ -3,6 +3,8 @@ from typing import Literal
 import nshconfig as C


+from ..util._useful_types import SupportsRichComparisonT
+


 class MetricConfig(C.Config):
@@ -35,3 +37,6 @@ class MetricConfig(C.Config):
     @property
     def best(self):
         return builtins.min if self.mode == "min" else builtins.max
+
+    def is_better(self, a: SupportsRichComparisonT, b: SupportsRichComparisonT) -> bool:
+        return self.best(a, b) == a
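A small sketch of the new comparison helper's semantics (hypothetical values; assumes `name` and `mode` are constructor fields, as the class body suggests):

    from nshtrainer.metrics import MetricConfig

    metric = MetricConfig(name="val/loss", mode="min")
    assert metric.is_better(0.12, 0.34)      # lower is better when mode == "min"
    assert not metric.is_better(0.34, 0.12)  # for mode == "max" the comparison flips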
@@ -37,6 +37,7 @@ from typing_extensions import Self, TypedDict, TypeVar, override

 from .._checkpoint.loader import CheckpointLoadingConfig
 from ..callbacks import (
+    BestCheckpointCallbackConfig,
     CallbackConfig,
     LatestEpochCheckpointCallbackConfig,
     ModelCheckpointCallbackConfig,
@@ -771,6 +772,7 @@ class ReproducibilityConfig(C.Config):

 CheckpointCallbackConfig: TypeAlias = Annotated[
     ModelCheckpointCallbackConfig
+    | BestCheckpointCallbackConfig
     | LatestEpochCheckpointCallbackConfig
     | OnExceptionCheckpointCallbackConfig,
     C.Field(discriminator="name"),
@@ -782,7 +784,8 @@ class CheckpointSavingConfig(CallbackConfigBase):
     """Enable checkpoint saving."""

     checkpoint_callbacks: Sequence[CheckpointCallbackConfig] = [
-        ModelCheckpointCallbackConfig(),
+        # ModelCheckpointCallbackConfig(),
+        BestCheckpointCallbackConfig(),
         LatestEpochCheckpointCallbackConfig(),
         OnExceptionCheckpointCallbackConfig(),
     ]
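For context, this is roughly what the new default saving setup amounts to; a user overriding `checkpoint_callbacks` would list the same config objects explicitly (names exported from nshtrainer.callbacks per the diff above):

    from nshtrainer.callbacks import (
        BestCheckpointCallbackConfig,
        LatestEpochCheckpointCallbackConfig,
        OnExceptionCheckpointCallbackConfig,
    )

    # New defaults: best-metric checkpointing replaces the plain ModelCheckpoint default,
    # alongside latest-epoch and on-exception checkpoints.
    checkpoint_callbacks = [
        BestCheckpointCallbackConfig(),
        LatestEpochCheckpointCallbackConfig(),
        OnExceptionCheckpointCallbackConfig(),
    ]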
@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, cast

 from lightning.pytorch.trainer.connectors.checkpoint_connector import (
-    _CheckpointConnector,
+    _CheckpointConnector as _LightningCheckpointConnector,
 )
 from lightning.pytorch.trainer.states import TrainerFn
 from typing_extensions import override
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)


-class CheckpointConnector(_CheckpointConnector):
+class _CheckpointConnector(_LightningCheckpointConnector):
     def __resolve_auto_ckpt_path(
         self,
         ckpt_path: str | Path | None,
@@ -26,6 +26,7 @@ from ..model.config import (
     StrategyConfigProtocol,
 )
 from ._runtime_callback import RuntimeTrackerCallback, Stage
+from .checkpoint_connector import _CheckpointConnector
 from .signal_connector import _SignalConnector

 log = logging.getLogger(__name__)
@@ -297,6 +298,9 @@ class Trainer(LightningTrainer):
         # Replace the signal connector with our own.
         self._signal_connector = _SignalConnector(self)

+        # Replace the checkpoint connector with our own.
+        self._checkpoint_connector = _CheckpointConnector(self)
+
         # Print out the log dir, so that we can easily find it in the logs.
         if log_dir := self.log_dir:
             log_dir = str(Path(log_dir).resolve())
@@ -0,0 +1,307 @@
+"""Credit to useful-types from https://github.com/hauntsaninja/useful_types"""
+
+from __future__ import annotations
+
+from collections.abc import Awaitable, Iterable, Iterator, Sequence, Sized
+from collections.abc import Set as AbstractSet
+from os import PathLike
+from typing import Any, TypeVar, overload
+
+from typing_extensions import Buffer, Literal, Protocol, SupportsIndex, TypeAlias
+
+_KT = TypeVar("_KT")
+_KT_co = TypeVar("_KT_co", covariant=True)
+_KT_contra = TypeVar("_KT_contra", contravariant=True)
+_VT = TypeVar("_VT")
+_VT_co = TypeVar("_VT_co", covariant=True)
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
+_T_contra = TypeVar("_T_contra", contravariant=True)
+
+# For partially known annotations. Usually, fields where type annotations
+# haven't been added are left unannotated, but in some situations this
+# isn't possible or a type is already partially known. In cases like these,
+# use Incomplete instead of Any as a marker. For example, use
+# "Incomplete | None" instead of "Any | None".
+Incomplete: TypeAlias = Any
+
+
+class IdentityFunction(Protocol):
+    def __call__(self, __x: _T) -> _T: ...
+
+
+# ====================
+# Comparison protocols
+# ====================
+
+
+class SupportsDunderLT(Protocol[_T_contra]):
+    def __lt__(self, __other: _T_contra) -> bool: ...
+
+
+class SupportsDunderGT(Protocol[_T_contra]):
+    def __gt__(self, __other: _T_contra) -> bool: ...
+
+
+class SupportsDunderLE(Protocol[_T_contra]):
+    def __le__(self, __other: _T_contra) -> bool: ...
+
+
+class SupportsDunderGE(Protocol[_T_contra]):
+    def __ge__(self, __other: _T_contra) -> bool: ...
+
+
+class SupportsAllComparisons(
+    SupportsDunderLT[Any],
+    SupportsDunderGT[Any],
+    SupportsDunderLE[Any],
+    SupportsDunderGE[Any],
+    Protocol,
+): ...
+
+
+SupportsRichComparison: TypeAlias = SupportsDunderLT[Any] | SupportsDunderGT[Any]
+SupportsRichComparisonT = TypeVar(
+    "SupportsRichComparisonT", bound=SupportsRichComparison
+)
+
+# ====================
+# Dunder protocols
+# ====================
+
+
+class SupportsNext(Protocol[_T_co]):
+    def __next__(self) -> _T_co: ...
+
+
+class SupportsAnext(Protocol[_T_co]):
+    def __anext__(self) -> Awaitable[_T_co]: ...
+
+
+class SupportsAdd(Protocol[_T_contra, _T_co]):
+    def __add__(self, __x: _T_contra) -> _T_co: ...
+
+
+class SupportsRAdd(Protocol[_T_contra, _T_co]):
+    def __radd__(self, __x: _T_contra) -> _T_co: ...
+
+
+class SupportsSub(Protocol[_T_contra, _T_co]):
+    def __sub__(self, __x: _T_contra) -> _T_co: ...
+
+
+class SupportsRSub(Protocol[_T_contra, _T_co]):
+    def __rsub__(self, __x: _T_contra) -> _T_co: ...
+
+
+class SupportsDivMod(Protocol[_T_contra, _T_co]):
+    def __divmod__(self, __other: _T_contra) -> _T_co: ...
+
+
+class SupportsRDivMod(Protocol[_T_contra, _T_co]):
+    def __rdivmod__(self, __other: _T_contra) -> _T_co: ...
+
+
+# This protocol is generic over the iterator type, while Iterable is
+# generic over the type that is iterated over.
+class SupportsIter(Protocol[_T_co]):
+    def __iter__(self) -> _T_co: ...
+
+
+# This protocol is generic over the iterator type, while AsyncIterable is
+# generic over the type that is iterated over.
+class SupportsAiter(Protocol[_T_co]):
+    def __aiter__(self) -> _T_co: ...
+
+
+class SupportsLenAndGetItem(Protocol[_T_co]):
+    def __len__(self) -> int: ...
+    def __getitem__(self, __k: int) -> _T_co: ...
+
+
+class SupportsTrunc(Protocol):
+    def __trunc__(self) -> int: ...
+
+
+# ====================
+# Mapping-like protocols
+# ====================
+
+
+class SupportsItems(Protocol[_KT_co, _VT_co]):
+    def items(self) -> AbstractSet[tuple[_KT_co, _VT_co]]: ...
+
+
+class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
+    def keys(self) -> Iterable[_KT]: ...
+    def __getitem__(self, __key: _KT) -> _VT_co: ...
+
+
+class SupportsGetItem(Protocol[_KT_contra, _VT_co]):
+    def __contains__(self, __x: Any) -> bool: ...
+    def __getitem__(self, __key: _KT_contra) -> _VT_co: ...
+
+
+class SupportsItemAccess(SupportsGetItem[_KT_contra, _VT], Protocol[_KT_contra, _VT]):
+    def __setitem__(self, __key: _KT_contra, __value: _VT) -> None: ...
+    def __delitem__(self, __key: _KT_contra) -> None: ...
+
+
+# ====================
+# File handling
+# ====================
+
+StrPath: TypeAlias = str | PathLike[str]
+BytesPath: TypeAlias = bytes | PathLike[bytes]
+StrOrBytesPath: TypeAlias = str | bytes | PathLike[str] | PathLike[bytes]
+
+OpenTextModeUpdating: TypeAlias = Literal[
+    "r+",
+    "+r",
+    "rt+",
+    "r+t",
+    "+rt",
+    "tr+",
+    "t+r",
+    "+tr",
+    "w+",
+    "+w",
+    "wt+",
+    "w+t",
+    "+wt",
+    "tw+",
+    "t+w",
+    "+tw",
+    "a+",
+    "+a",
+    "at+",
+    "a+t",
+    "+at",
+    "ta+",
+    "t+a",
+    "+ta",
+    "x+",
+    "+x",
+    "xt+",
+    "x+t",
+    "+xt",
+    "tx+",
+    "t+x",
+    "+tx",
+]
+OpenTextModeWriting: TypeAlias = Literal[
+    "w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"
+]
+OpenTextModeReading: TypeAlias = Literal[
+    "r", "rt", "tr", "U", "rU", "Ur", "rtU", "rUt", "Urt", "trU", "tUr", "Utr"
+]
+OpenTextMode: TypeAlias = (
+    OpenTextModeUpdating | OpenTextModeWriting | OpenTextModeReading
+)
+OpenBinaryModeUpdating: TypeAlias = Literal[
+    "rb+",
+    "r+b",
+    "+rb",
+    "br+",
+    "b+r",
+    "+br",
+    "wb+",
+    "w+b",
+    "+wb",
+    "bw+",
+    "b+w",
+    "+bw",
+    "ab+",
+    "a+b",
+    "+ab",
+    "ba+",
+    "b+a",
+    "+ba",
+    "xb+",
+    "x+b",
+    "+xb",
+    "bx+",
+    "b+x",
+    "+bx",
+]
+OpenBinaryModeWriting: TypeAlias = Literal["wb", "bw", "ab", "ba", "xb", "bx"]
+OpenBinaryModeReading: TypeAlias = Literal[
+    "rb", "br", "rbU", "rUb", "Urb", "brU", "bUr", "Ubr"
+]
+OpenBinaryMode: TypeAlias = (
+    OpenBinaryModeUpdating | OpenBinaryModeReading | OpenBinaryModeWriting
+)
+
+
+class HasFileno(Protocol):
+    def fileno(self) -> int: ...
+
+
+FileDescriptor: TypeAlias = int
+FileDescriptorLike: TypeAlias = int | HasFileno
+FileDescriptorOrPath: TypeAlias = int | StrOrBytesPath
+
+
+class SupportsRead(Protocol[_T_co]):
+    def read(self, __length: int = ...) -> _T_co: ...
+
+
+class SupportsReadline(Protocol[_T_co]):
+    def readline(self, __length: int = ...) -> _T_co: ...
+
+
+class SupportsNoArgReadline(Protocol[_T_co]):
+    def readline(self) -> _T_co: ...
+
+
+class SupportsWrite(Protocol[_T_contra]):
+    def write(self, __s: _T_contra) -> object: ...
+
+
+# ====================
+# Buffer protocols
+# ====================
+
+# Unfortunately PEP 688 does not allow us to distinguish read-only
+# from writable buffers. We use these aliases for readability for now.
+# Perhaps a future extension of the buffer protocol will allow us to
+# distinguish these cases in the type system.
+ReadOnlyBuffer: TypeAlias = Buffer
+# Anything that implements the read-write buffer interface.
+WriteableBuffer: TypeAlias = Buffer
+# Same as WriteableBuffer, but also includes read-only buffer types (like bytes).
+ReadableBuffer: TypeAlias = Buffer
+
+
+class SliceableBuffer(Buffer, Protocol):
+    def __getitem__(self, __slice: slice) -> Sequence[int]: ...
+
+
+class IndexableBuffer(Buffer, Protocol):
+    def __getitem__(self, __i: int) -> int: ...
+
+
+class SupportsGetItemBuffer(SliceableBuffer, IndexableBuffer, Protocol):
+    def __contains__(self, __x: Any) -> bool: ...
+    @overload
+    def __getitem__(self, __slice: slice) -> Sequence[int]: ...
+    @overload
+    def __getitem__(self, __i: int) -> int: ...
+
+
+class SizedBuffer(Sized, Buffer, Protocol): ...
+
+
+# Source from https://github.com/python/typing/issues/256#issuecomment-1442633430
+# This works because str.__contains__ does not accept object (either in typeshed or at runtime)
+class SequenceNotStr(Protocol[_T_co]):
+    @overload
+    def __getitem__(self, index: SupportsIndex, /) -> _T_co: ...
+    @overload
+    def __getitem__(self, index: slice, /) -> Sequence[_T_co]: ...
+    def __contains__(self, value: object, /) -> bool: ...
+    def __len__(self) -> int: ...
+    def __iter__(self) -> Iterator[_T_co]: ...
+    def index(self, value: Any, start: int = 0, stop: int = ..., /) -> int: ...
+    def count(self, value: Any, /) -> int: ...
+    def __reversed__(self) -> Iterator[_T_co]: ...