nshtrainer 0.33.2__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/config.py CHANGED
@@ -1,9 +1,21 @@
  from nshconfig._config import Config as Config
  from nshsnap._config import SnapshotConfig as SnapshotConfig

+ from nshtrainer._checkpoint.loader import (
+     BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
+ )
  from nshtrainer._checkpoint.loader import (
      CheckpointLoadingConfig as CheckpointLoadingConfig,
  )
+ from nshtrainer._checkpoint.loader import (
+     CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
+ )
+ from nshtrainer._checkpoint.loader import (
+     LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
+ )
+ from nshtrainer._checkpoint.loader import (
+     UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
+ )
  from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
  from nshtrainer._directory import DirectoryConfig as DirectoryConfig
  from nshtrainer._hf_hub import (
@@ -53,13 +65,13 @@ from nshtrainer.callbacks.throughput_monitor import (
  )
  from nshtrainer.callbacks.timer import EpochTimerConfig as EpochTimerConfig
  from nshtrainer.callbacks.wandb_watch import WandbWatchConfig as WandbWatchConfig
+ from nshtrainer.config import LRSchedulerConfig as LRSchedulerConfig
  from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
  from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
  from nshtrainer.loggers.tensorboard import (
      TensorboardLoggerConfig as TensorboardLoggerConfig,
  )
  from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
- from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
  from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
  from nshtrainer.lr_scheduler.linear_warmup_cosine import (
      DurationConfig as DurationConfig,
@@ -129,9 +141,31 @@ from nshtrainer.util._environment_info import (
      EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
  )
  from nshtrainer.util._environment_info import EnvironmentConfig as EnvironmentConfig
+ from nshtrainer.util._environment_info import (
+     EnvironmentCUDAConfig as EnvironmentCUDAConfig,
+ )
+ from nshtrainer.util._environment_info import (
+     EnvironmentGPUConfig as EnvironmentGPUConfig,
+ )
+ from nshtrainer.util._environment_info import (
+     EnvironmentHardwareConfig as EnvironmentHardwareConfig,
+ )
  from nshtrainer.util._environment_info import (
      EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
  )
+ from nshtrainer.util._environment_info import (
+     EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
+ )
+ from nshtrainer.util._environment_info import (
+     EnvironmentPackageConfig as EnvironmentPackageConfig,
+ )
  from nshtrainer.util._environment_info import (
      EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
  )
+ from nshtrainer.util._environment_info import (
+     EnvironmentSnapshotConfig as EnvironmentSnapshotConfig,
+ )
+ from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
+ from nshtrainer.util.config.dtype import DTypeConfig as DTypeConfig
+ from nshtrainer.util.config.duration import EpochsConfig as EpochsConfig
+ from nshtrainer.util.config.duration import StepsConfig as StepsConfig
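These additions surface the checkpoint-strategy, environment, and dtype configs at the top level. A minimal sketch of the resulting import surface (only the imported names come from the additions above; nothing else is implied):

    from nshtrainer.config import (
        BestCheckpointStrategyConfig,
        DTypeConfig,
        EnvironmentGPUConfig,
        LastCheckpointStrategyConfig,
    )

    # Previously these names were only importable from their defining submodules,
    # e.g. nshtrainer._checkpoint.loader and nshtrainer.util.config.dtype.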
nshtrainer/data/balanced_batch_sampler.py CHANGED
@@ -1,13 +1,12 @@
  import heapq
  import logging
- from functools import cached_property
  from typing import Any, Protocol, runtime_checkable

  import numpy as np
  import torch
  import torch.distributed
  from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
- from torch.utils.data import BatchSampler, Dataset, DistributedSampler
+ from torch.utils.data import BatchSampler, DistributedSampler
  from typing_extensions import override

  log = logging.getLogger(__name__)
@@ -47,24 +46,16 @@ class DatasetWithSizes(Protocol):
      def data_sizes(self, indices: list[int]) -> np.ndarray: ...


- class BalancedBatchSampler(BatchSampler):
-     @staticmethod
-     def _ensure_supported(dataset: Any):
-         if not isinstance(dataset, Dataset):
-             raise ValueError(
-                 "BalancedBatchSampler requires a dataset that implements `__getitem__`"
-             )
-
-         if not isinstance(dataset, DatasetWithSizes):
-             raise ValueError(
-                 "BalancedBatchSampler requires a dataset that implements `data_sizes`"
-             )
+ @runtime_checkable
+ class DataSizesFunction(Protocol):
+     def __call__(self, dataset: Any, indices: list[int]) -> np.ndarray: ...

-         log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
-         return dataset

+ class BalancedBatchSampler(BatchSampler):
      @staticmethod
-     def _unwrap_dataset(dataset: Dataset) -> Dataset:
+     def _unwrap_dataset(dataset: Any):
+         # Lightning's DistributedSampler wraps the dataset in a _DatasetSamplerWrapper,
+         # so we need to unwrap it to get the actual dataset.
          if isinstance(dataset, _DatasetSamplerWrapper):
              if (data_source := getattr(dataset._sampler, "data_source", None)) is None:
                  raise ValueError("Could not unwrap dataset from _DatasetSamplerWrapper")
@@ -79,12 +70,6 @@ class BalancedBatchSampler(BatchSampler):
          )
          return self.sampler

-     @cached_property
-     def dataset(self):
-         return self._ensure_supported(
-             self._unwrap_dataset(self.distributed_sampler.dataset)
-         )
-
      def __init__(
          self,
          sampler: DistributedSampler,
@@ -92,10 +77,12 @@
          batch_size: int,
          device: torch.device,
          drop_last: bool = False,
+         data_sizes_fn: DataSizesFunction | None = None,
      ):
          super().__init__(sampler, batch_size, drop_last=drop_last)

          self._device = device
+         self._data_sizes_fn = data_sizes_fn

          log.info(
              f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}"
@@ -105,17 +92,34 @@
      def _dist_enabled():
          return torch.distributed.is_available() and torch.distributed.is_initialized()

+     def _dataset_sizes(self, indices: list[int]) -> np.ndarray:
+         dataset = self._unwrap_dataset(self.distributed_sampler.dataset)
+         # Dataset must either implement `data_sizes`, or we need to provide a custom
+         # implementation of the dataset sizes function.
+         if isinstance(dataset, DatasetWithSizes):
+             log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
+             return dataset.data_sizes(indices)
+
+         if (data_sizes_fn := self._data_sizes_fn) is not None:
+             return data_sizes_fn(dataset, indices)
+
+         raise ValueError(
+             "Dataset must implement the `data_sizes` method, "
+             "or a custom data_sizes_fn must be provided "
+             "to the BalancedBatchSampler."
+         )
+
      @override
      def __iter__(self):
          if not self._dist_enabled():
              yield from super().__iter__()
              return

-         for batch_idx in super().__iter__():
-             sizes = self.dataset.data_sizes(batch_idx)
+         for batch_idxs in super().__iter__():
+             sizes = self._dataset_sizes(batch_idxs)
              idx_sizes = torch.stack(
                  [
-                     torch.tensor(batch_idx, device=self._device),
+                     torch.tensor(batch_idxs, device=self._device),
                      torch.tensor(sizes, device=self._device),
                  ]
              )
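The new `data_sizes_fn` hook lets the balanced sampler work with datasets that do not implement `data_sizes` themselves. A minimal sketch of how it might be wired up, assuming a hypothetical list-of-sequences dataset (only the `BalancedBatchSampler` parameters and the `(dataset, indices) -> np.ndarray` callable shape come from the diff above; everything else is illustrative):

    import numpy as np
    import torch
    from torch.utils.data import DistributedSampler

    from nshtrainer.data.balanced_batch_sampler import BalancedBatchSampler

    # Placeholder dataset: each sample is a variable-length sequence.
    my_dataset = [list(range(n)) for n in (3, 5, 2, 7)]

    def sequence_lengths(dataset, indices: list[int]) -> np.ndarray:
        # Hypothetical cost function: one size value per requested index.
        return np.array([len(dataset[i]) for i in indices])

    sampler = DistributedSampler(my_dataset, num_replicas=1, rank=0)
    batch_sampler = BalancedBatchSampler(
        sampler,
        batch_size=2,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        data_sizes_fn=sequence_lengths,
    )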
nshtrainer/trainer/trainer.py CHANGED
@@ -18,6 +18,7 @@ from typing_extensions import Unpack, assert_never, override

  from .._checkpoint.metadata import _write_checkpoint_metadata
  from ..callbacks.base import resolve_all_callbacks
+ from ..util.bf16 import is_bf16_supported_no_emulation
  from ._config import (
      AcceleratorConfigProtocol,
      LightningTrainerKwargs,
@@ -33,30 +34,6 @@ if TYPE_CHECKING:
  log = logging.getLogger(__name__)


- def _is_bf16_supported_no_emulation():
-     r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
-     version = getattr(torch, "version")
-
-     # Check for ROCm, if true return true, no ROCM_VERSION check required,
-     # since it is supported on AMD GPU archs.
-     if version.hip:
-         return True
-
-     device = torch.cuda.current_device()
-
-     # Check for CUDA version and device compute capability.
-     # This is a fast way to check for it.
-     cuda_version = version.cuda
-     if (
-         cuda_version is not None
-         and int(cuda_version.split(".")[0]) >= 11
-         and torch.cuda.get_device_properties(device).major >= 8
-     ):
-         return True
-
-     return False
-
-
  class Trainer(LightningTrainer):
      @classmethod
      def _pre_init(cls, config: "BaseConfig"):
@@ -188,7 +165,7 @@ class Trainer(LightningTrainer):
          try:
              resolved_precision = (
                  "bf16-mixed"
-                 if _is_bf16_supported_no_emulation()
+                 if is_bf16_supported_no_emulation()
                  else "16-mixed"
              )
          except BaseException:
nshtrainer/util/bf16.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+
+
+ def is_bf16_supported_no_emulation():
+     r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
+     version = getattr(torch, "version")
+
+     # Check for ROCm, if true return true, no ROCM_VERSION check required,
+     # since it is supported on AMD GPU archs.
+     if version.hip:
+         return True
+
+     device = torch.cuda.current_device()
+
+     # Check for CUDA version and device compute capability.
+     # This is a fast way to check for it.
+     cuda_version = version.cuda
+     if (
+         cuda_version is not None
+         and int(cuda_version.split(".")[0]) >= 11
+         and torch.cuda.get_device_properties(device).major >= 8
+     ):
+         return True
+
+     return False
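A short usage sketch of the relocated helper, mirroring the precision-resolution pattern in the trainer.py hunk above (the `precision` variable is illustrative):

    from nshtrainer.util.bf16 import is_bf16_supported_no_emulation

    # Prefer bf16 mixed precision when the current CUDA/ROCm device supports it.
    precision = "bf16-mixed" if is_bf16_supported_no_emulation() else "16-mixed"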
nshtrainer/util/config/__init__.py CHANGED
@@ -1,4 +1,5 @@
  from . import duration as duration
+ from .dtype import DTypeConfig as DTypeConfig
  from .duration import DurationConfig as DurationConfig
  from .duration import EpochsConfig as EpochsConfig
  from .duration import StepsConfig as StepsConfig
nshtrainer/util/config/dtype.py ADDED
@@ -0,0 +1,89 @@
+ from typing import TYPE_CHECKING, Literal, TypeAlias
+
+ import nshconfig as C
+ import torch
+ from typing_extensions import assert_never
+
+ from ..bf16 import is_bf16_supported_no_emulation
+
+ if TYPE_CHECKING:
+     from ...model.base import BaseConfig
+
+ DTypeName: TypeAlias = Literal[
+     "float32",
+     "float",
+     "float64",
+     "double",
+     "float16",
+     "bfloat16",
+     "float8_e4m3fn",
+     "float8_e4m3fnuz",
+     "float8_e5m2",
+     "float8_e5m2fnuz",
+     "half",
+     "uint8",
+     "uint16",
+     "uint32",
+     "uint64",
+     "int8",
+     "int16",
+     "short",
+     "int32",
+     "int",
+     "int64",
+     "long",
+     "complex32",
+     "complex64",
+     "chalf",
+     "cfloat",
+     "complex128",
+     "cdouble",
+     "quint8",
+     "qint8",
+     "qint32",
+     "bool",
+     "quint4x2",
+     "quint2x4",
+     "bits1x8",
+     "bits2x4",
+     "bits4x2",
+     "bits8",
+     "bits16",
+ ]
+
+
+ class DTypeConfig(C.Config):
+     name: DTypeName
+     """The name of the dtype."""
+
+     @classmethod
+     def from_base_config(cls, config: "BaseConfig"):
+         if (precision := config.trainer.precision) is None:
+             precision = "32-true"
+
+         match precision:
+             case "16-mixed-auto":
+                 return (
+                     cls(name="bfloat16")
+                     if is_bf16_supported_no_emulation()
+                     else cls(name="float16")
+                 )
+             case "fp16-mixed":
+                 return cls(name="float16")
+             case "bf16-mixed":
+                 return cls(name="bfloat16")
+             case "32-true":
+                 return cls(name="float32")
+             case "64-true":
+                 return cls(name="float64")
+             case _:
+                 assert_never(config.trainer.precision)
+
+     @property
+     def torch_dtype(self):
+         if ((dtype := getattr(torch, self.name, None)) is None) or not isinstance(
+             dtype, torch.dtype
+         ):
+             raise ValueError(f"Unknown dtype {self.name}")
+
+         return dtype
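A minimal usage sketch of the new `DTypeConfig` (only `DTypeConfig(name=...)` and the `torch_dtype` property come from the file above; the tensor creation is illustrative):

    import torch

    from nshtrainer.util.config.dtype import DTypeConfig

    dtype_config = DTypeConfig(name="bfloat16")
    assert dtype_config.torch_dtype is torch.bfloat16

    # Use the resolved torch.dtype wherever a dtype is needed.
    x = torch.zeros(4, dtype=dtype_config.torch_dtype)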
nshtrainer-0.33.2.dist-info/METADATA → nshtrainer-0.34.0.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.33.2
+ Version: 0.34.0
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com
nshtrainer-0.33.2.dist-info/RECORD → nshtrainer-0.34.0.dist-info/RECORD RENAMED
@@ -30,9 +30,9 @@ nshtrainer/callbacks/shared_parameters.py,sha256=fqlDweFDXPV_bfcAWpRgaJIad9i5Aeh
  nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
  nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
  nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
- nshtrainer/config.py,sha256=EqvSp06RmSkCvo13-5bkecoCcE1nVViwvFIivTZOXoI,6883
+ nshtrainer/config.py,sha256=skar_Wfz50_sU2NZS8PEjqofWeon4g4cyIgby3Da81g,8308
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
- nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
+ nshtrainer/data/balanced_batch_sampler.py,sha256=WAjhbO9EsZ_UadhdW3obBsjvEDMc2V-irpjegqIb7AI,4791
  nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
  nshtrainer/ll/__init__.py,sha256=L-aTi1V1bbvnZjOro8NvI393zbHQSFR9movWSRK9Mds,2477
  nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
@@ -84,10 +84,12 @@ nshtrainer/trainer/_config.py,sha256=ZIodM5Ek1lpkWFhQ_VfmKR7q1mZFFwtjfx8FH72H8WM
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
  nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
- nshtrainer/trainer/trainer.py,sha256=iYueHW-m8fHyC8SQuXmpgxq_-GUa7pAJik7rDFPXmy0,17499
+ nshtrainer/trainer/trainer.py,sha256=8T4LB31ygXXS3DECkvD2uqgElAxkulacYvZyL_-imJs,16839
  nshtrainer/util/_environment_info.py,sha256=CFUUZYjXhBLWGc0jtPNOaZgYMueUDEHpEaWFA1f3GoY,24213
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
- nshtrainer/util/config/__init__.py,sha256=N2AOhaZC93DszvCdwvNL9KgnzJ2M3P-esFBY6VGih6Y,190
+ nshtrainer/util/bf16.py,sha256=VUnIG6aA4XtZscZc_dxv5ln_jlEbdU3eMFwDb5SEWSI,726
+ nshtrainer/util/config/__init__.py,sha256=o8fwPf_dctE_7CAkT0wNOBkvmxnzYzXeHpLedrZLt54,236
+ nshtrainer/util/config/dtype.py,sha256=JtYjrcBFNBlziJnLAE6QS0QV4PUXhGspYH1hNFrB3ks,1965
  nshtrainer/util/config/duration.py,sha256=pgIKQ88Dg8y1YAKUvUsNWu9hc9O79kdYBfgmC3a_-kQ,728
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
  nshtrainer/util/path.py,sha256=VkpuhR4GaZtSFBVqbGAvfjcrU-PR8xwiGzzwFNOWP9c,2995
@@ -95,6 +97,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
- nshtrainer-0.33.2.dist-info/METADATA,sha256=zDC_xehJGE3RlCACScFpu64qL1TKd_D8VyhjmRxNDkw,916
- nshtrainer-0.33.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- nshtrainer-0.33.2.dist-info/RECORD,,
+ nshtrainer-0.34.0.dist-info/METADATA,sha256=GYC9ejdKV3MCyOFhJcFjI-uedTWLGWj-SE5S79ruug4,916
+ nshtrainer-0.34.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ nshtrainer-0.34.0.dist-info/RECORD,,