nshtrainer 0.33.1__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nshtrainer/config.py CHANGED
@@ -1,9 +1,21 @@
  from nshconfig._config import Config as Config
  from nshsnap._config import SnapshotConfig as SnapshotConfig
 
+ from nshtrainer._checkpoint.loader import (
+     BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
+ )
  from nshtrainer._checkpoint.loader import (
      CheckpointLoadingConfig as CheckpointLoadingConfig,
  )
+ from nshtrainer._checkpoint.loader import (
+     CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
+ )
+ from nshtrainer._checkpoint.loader import (
+     LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
+ )
+ from nshtrainer._checkpoint.loader import (
+     UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
+ )
  from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
  from nshtrainer._directory import DirectoryConfig as DirectoryConfig
  from nshtrainer._hf_hub import (
@@ -53,48 +65,13 @@ from nshtrainer.callbacks.throughput_monitor import (
  )
  from nshtrainer.callbacks.timer import EpochTimerConfig as EpochTimerConfig
  from nshtrainer.callbacks.wandb_watch import WandbWatchConfig as WandbWatchConfig
- from nshtrainer.config import UUID1 as UUID1
- from nshtrainer.config import UUID3 as UUID3
- from nshtrainer.config import UUID4 as UUID4
- from nshtrainer.config import UUID5 as UUID5
- from nshtrainer.config import AmqpDsn as AmqpDsn
- from nshtrainer.config import AnyHttpUrl as AnyHttpUrl
- from nshtrainer.config import AnyWebsocketUrl as AnyWebsocketUrl
- from nshtrainer.config import Base64Bytes as Base64Bytes
- from nshtrainer.config import Base64Str as Base64Str
- from nshtrainer.config import Base64UrlBytes as Base64UrlBytes
- from nshtrainer.config import Base64UrlStr as Base64UrlStr
- from nshtrainer.config import ClickHouseDsn as ClickHouseDsn
- from nshtrainer.config import CockroachDsn as CockroachDsn
- from nshtrainer.config import DirectoryPath as DirectoryPath
- from nshtrainer.config import FilePath as FilePath
- from nshtrainer.config import FileUrl as FileUrl
- from nshtrainer.config import FiniteFloat as FiniteFloat
- from nshtrainer.config import FtpUrl as FtpUrl
- from nshtrainer.config import HttpUrl as HttpUrl
- from nshtrainer.config import KafkaDsn as KafkaDsn
- from nshtrainer.config import MariaDBDsn as MariaDBDsn
- from nshtrainer.config import MongoDsn as MongoDsn
- from nshtrainer.config import MySQLDsn as MySQLDsn
- from nshtrainer.config import NatsDsn as NatsDsn
- from nshtrainer.config import NewPath as NewPath
- from nshtrainer.config import OnErrorOmit as OnErrorOmit
- from nshtrainer.config import PostgresDsn as PostgresDsn
- from nshtrainer.config import RedisDsn as RedisDsn
- from nshtrainer.config import SnowflakeDsn as SnowflakeDsn
- from nshtrainer.config import StrictBool as StrictBool
- from nshtrainer.config import StrictBytes as StrictBytes
- from nshtrainer.config import StrictFloat as StrictFloat
- from nshtrainer.config import StrictInt as StrictInt
- from nshtrainer.config import StrictStr as StrictStr
- from nshtrainer.config import WebsocketUrl as WebsocketUrl
+ from nshtrainer.config import LRSchedulerConfig as LRSchedulerConfig
  from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
  from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
  from nshtrainer.loggers.tensorboard import (
      TensorboardLoggerConfig as TensorboardLoggerConfig,
  )
  from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
- from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
  from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
  from nshtrainer.lr_scheduler.linear_warmup_cosine import (
      DurationConfig as DurationConfig,
@@ -164,9 +141,31 @@ from nshtrainer.util._environment_info import (
      EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
  )
  from nshtrainer.util._environment_info import EnvironmentConfig as EnvironmentConfig
+ from nshtrainer.util._environment_info import (
+     EnvironmentCUDAConfig as EnvironmentCUDAConfig,
+ )
+ from nshtrainer.util._environment_info import (
+     EnvironmentGPUConfig as EnvironmentGPUConfig,
+ )
+ from nshtrainer.util._environment_info import (
+     EnvironmentHardwareConfig as EnvironmentHardwareConfig,
+ )
  from nshtrainer.util._environment_info import (
      EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
  )
+ from nshtrainer.util._environment_info import (
+     EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
+ )
+ from nshtrainer.util._environment_info import (
+     EnvironmentPackageConfig as EnvironmentPackageConfig,
+ )
  from nshtrainer.util._environment_info import (
      EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
  )
+ from nshtrainer.util._environment_info import (
+     EnvironmentSnapshotConfig as EnvironmentSnapshotConfig,
+ )
+ from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
+ from nshtrainer.util.config.dtype import DTypeConfig as DTypeConfig
+ from nshtrainer.util.config.duration import EpochsConfig as EpochsConfig
+ from nshtrainer.util.config.duration import StepsConfig as StepsConfig
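
The nshtrainer.config flat namespace now re-exports the checkpoint-loading strategy configs, additional environment-info configs, and the new dtype/duration configs, while the pydantic convenience aliases (UUID*, *Dsn, Strict*, URL and path types) are dropped from it. A minimal sketch of how downstream imports might look after this change; the names come from the re-exports shown above, and pointing at pydantic for the dropped aliases is an assumption, not something stated in the diff:

    # Sketch only: names newly re-exported by nshtrainer.config in 0.34.0.
    from nshtrainer.config import (
        BestCheckpointStrategyConfig,
        LastCheckpointStrategyConfig,
        UserProvidedPathCheckpointStrategyConfig,
        DTypeConfig,
        EpochsConfig,
        StepsConfig,
    )

    # Assumption: code that imported the removed aliases from nshtrainer.config
    # (e.g. StrictInt, HttpUrl, PostgresDsn) now imports them from pydantic directly.
    from pydantic import HttpUrl, StrictInt
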
nshtrainer/data/balanced_batch_sampler.py CHANGED
@@ -1,13 +1,12 @@
  import heapq
  import logging
- from functools import cached_property
  from typing import Any, Protocol, runtime_checkable
 
  import numpy as np
  import torch
  import torch.distributed
  from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
- from torch.utils.data import BatchSampler, Dataset, DistributedSampler
+ from torch.utils.data import BatchSampler, DistributedSampler
  from typing_extensions import override
 
  log = logging.getLogger(__name__)
@@ -47,24 +46,16 @@ class DatasetWithSizes(Protocol):
      def data_sizes(self, indices: list[int]) -> np.ndarray: ...
 
 
- class BalancedBatchSampler(BatchSampler):
-     @staticmethod
-     def _ensure_supported(dataset: Any):
-         if not isinstance(dataset, Dataset):
-             raise ValueError(
-                 "BalancedBatchSampler requires a dataset that implements `__getitem__`"
-             )
-
-         if not isinstance(dataset, DatasetWithSizes):
-             raise ValueError(
-                 "BalancedBatchSampler requires a dataset that implements `data_sizes`"
-             )
+ @runtime_checkable
+ class DataSizesFunction(Protocol):
+     def __call__(self, dataset: Any, indices: list[int]) -> np.ndarray: ...
 
-         log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
-         return dataset
 
+ class BalancedBatchSampler(BatchSampler):
      @staticmethod
-     def _unwrap_dataset(dataset: Dataset) -> Dataset:
+     def _unwrap_dataset(dataset: Any):
+         # Lightning's DistributedSampler wraps the dataset in a _DatasetSamplerWrapper,
+         # so we need to unwrap it to get the actual dataset.
          if isinstance(dataset, _DatasetSamplerWrapper):
              if (data_source := getattr(dataset._sampler, "data_source", None)) is None:
                  raise ValueError("Could not unwrap dataset from _DatasetSamplerWrapper")
@@ -79,12 +70,6 @@ class BalancedBatchSampler(BatchSampler):
              )
          return self.sampler
 
-     @cached_property
-     def dataset(self):
-         return self._ensure_supported(
-             self._unwrap_dataset(self.distributed_sampler.dataset)
-         )
-
      def __init__(
          self,
          sampler: DistributedSampler,
@@ -92,10 +77,12 @@ class BalancedBatchSampler(BatchSampler):
          batch_size: int,
          device: torch.device,
          drop_last: bool = False,
+         data_sizes_fn: DataSizesFunction | None = None,
      ):
          super().__init__(sampler, batch_size, drop_last=drop_last)
 
          self._device = device
+         self._data_sizes_fn = data_sizes_fn
 
          log.info(
              f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}"
@@ -105,17 +92,34 @@ class BalancedBatchSampler(BatchSampler):
      def _dist_enabled():
          return torch.distributed.is_available() and torch.distributed.is_initialized()
 
+     def _dataset_sizes(self, indices: list[int]) -> np.ndarray:
+         dataset = self._unwrap_dataset(self.distributed_sampler.dataset)
+         # Dataset must either implement `data_sizes`, or we need to provide a custom
+         # implementation of the dataset sizes function.
+         if isinstance(dataset, DatasetWithSizes):
+             log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
+             return dataset.data_sizes(indices)
+
+         if (data_sizes_fn := self._data_sizes_fn) is not None:
+             return data_sizes_fn(dataset, indices)
+
+         raise ValueError(
+             "Dataset must implement the `data_sizes` method, "
+             "or a custom data_sizes_fn must be provided "
+             "to the BalancedBatchSampler."
+         )
+
      @override
      def __iter__(self):
          if not self._dist_enabled():
              yield from super().__iter__()
              return
 
-         for batch_idx in super().__iter__():
-             sizes = self.dataset.data_sizes(batch_idx)
+         for batch_idxs in super().__iter__():
+             sizes = self._dataset_sizes(batch_idxs)
              idx_sizes = torch.stack(
                  [
-                     torch.tensor(batch_idx, device=self._device),
+                     torch.tensor(batch_idxs, device=self._device),
                      torch.tensor(sizes, device=self._device),
                  ]
              )
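
The sampler no longer validates and caches the dataset up front; instead the new `_dataset_sizes` helper prefers the dataset's own `data_sizes` method and falls back to a caller-supplied `data_sizes_fn` matching the `DataSizesFunction` protocol, raising a ValueError only if neither is available. A minimal sketch of the new keyword; the dataset, its `lengths` attribute, and the helper below are hypothetical, and only the constructor signature and the `(dataset, indices) -> np.ndarray` protocol come from the diff:

    import numpy as np
    import torch
    from torch.utils.data import DistributedSampler

    from nshtrainer.data.balanced_batch_sampler import BalancedBatchSampler

    def sizes_from_lengths(dataset, indices: list[int]) -> np.ndarray:
        # Hypothetical cost lookup: a per-sample "size", e.g. token or node count.
        return np.asarray([dataset.lengths[i] for i in indices])

    sampler = DistributedSampler(my_dataset)  # my_dataset is assumed to expose `.lengths`
    batch_sampler = BalancedBatchSampler(
        sampler,
        batch_size=32,
        device=torch.device("cuda"),
        data_sizes_fn=sizes_from_lengths,  # used only if the dataset lacks `data_sizes`
    )
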
nshtrainer/trainer/trainer.py CHANGED
@@ -18,6 +18,7 @@ from typing_extensions import Unpack, assert_never, override
 
  from .._checkpoint.metadata import _write_checkpoint_metadata
  from ..callbacks.base import resolve_all_callbacks
+ from ..util.bf16 import is_bf16_supported_no_emulation
  from ._config import (
      AcceleratorConfigProtocol,
      LightningTrainerKwargs,
@@ -33,30 +34,6 @@ if TYPE_CHECKING:
  log = logging.getLogger(__name__)
 
 
- def _is_bf16_supported_no_emulation():
-     r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
-     version = getattr(torch, "version")
-
-     # Check for ROCm, if true return true, no ROCM_VERSION check required,
-     # since it is supported on AMD GPU archs.
-     if version.hip:
-         return True
-
-     device = torch.cuda.current_device()
-
-     # Check for CUDA version and device compute capability.
-     # This is a fast way to check for it.
-     cuda_version = version.cuda
-     if (
-         cuda_version is not None
-         and int(cuda_version.split(".")[0]) >= 11
-         and torch.cuda.get_device_properties(device).major >= 8
-     ):
-         return True
-
-     return False
-
-
  class Trainer(LightningTrainer):
      @classmethod
      def _pre_init(cls, config: "BaseConfig"):
@@ -188,7 +165,7 @@ class Trainer(LightningTrainer):
          try:
              resolved_precision = (
                  "bf16-mixed"
-                 if _is_bf16_supported_no_emulation()
+                 if is_bf16_supported_no_emulation()
                  else "16-mixed"
              )
          except BaseException:
nshtrainer/util/bf16.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+
+
+ def is_bf16_supported_no_emulation():
+     r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
+     version = getattr(torch, "version")
+
+     # Check for ROCm, if true return true, no ROCM_VERSION check required,
+     # since it is supported on AMD GPU archs.
+     if version.hip:
+         return True
+
+     device = torch.cuda.current_device()
+
+     # Check for CUDA version and device compute capability.
+     # This is a fast way to check for it.
+     cuda_version = version.cuda
+     if (
+         cuda_version is not None
+         and int(cuda_version.split(".")[0]) >= 11
+         and torch.cuda.get_device_properties(device).major >= 8
+     ):
+         return True
+
+     return False
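
The bfloat16 capability probe now lives in nshtrainer/util/bf16.py so that both the trainer's precision resolution (hunk above) and the new dtype config can share it. A minimal usage sketch; the CUDA-availability guard is an assumption on my part, added because the function calls torch.cuda.current_device(), which is presumably why the trainer wraps the call in a try/except:

    import torch
    from nshtrainer.util.bf16 import is_bf16_supported_no_emulation

    # Sketch: mirror the trainer's "bf16 if supported, else fp16" resolution.
    if torch.cuda.is_available():
        precision = "bf16-mixed" if is_bf16_supported_no_emulation() else "16-mixed"
    else:
        precision = "32-true"
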
nshtrainer/util/config/__init__.py CHANGED
@@ -1,4 +1,5 @@
  from . import duration as duration
+ from .dtype import DTypeConfig as DTypeConfig
  from .duration import DurationConfig as DurationConfig
  from .duration import EpochsConfig as EpochsConfig
  from .duration import StepsConfig as StepsConfig
nshtrainer/util/config/dtype.py ADDED
@@ -0,0 +1,89 @@
+ from typing import TYPE_CHECKING, Literal, TypeAlias
+
+ import nshconfig as C
+ import torch
+ from typing_extensions import assert_never
+
+ from ..bf16 import is_bf16_supported_no_emulation
+
+ if TYPE_CHECKING:
+     from ...model.base import BaseConfig
+
+ DTypeName: TypeAlias = Literal[
+     "float32",
+     "float",
+     "float64",
+     "double",
+     "float16",
+     "bfloat16",
+     "float8_e4m3fn",
+     "float8_e4m3fnuz",
+     "float8_e5m2",
+     "float8_e5m2fnuz",
+     "half",
+     "uint8",
+     "uint16",
+     "uint32",
+     "uint64",
+     "int8",
+     "int16",
+     "short",
+     "int32",
+     "int",
+     "int64",
+     "long",
+     "complex32",
+     "complex64",
+     "chalf",
+     "cfloat",
+     "complex128",
+     "cdouble",
+     "quint8",
+     "qint8",
+     "qint32",
+     "bool",
+     "quint4x2",
+     "quint2x4",
+     "bits1x8",
+     "bits2x4",
+     "bits4x2",
+     "bits8",
+     "bits16",
+ ]
+
+
+ class DTypeConfig(C.Config):
+     name: DTypeName
+     """The name of the dtype."""
+
+     @classmethod
+     def from_base_config(cls, config: "BaseConfig"):
+         if (precision := config.trainer.precision) is None:
+             precision = "32-true"
+
+         match precision:
+             case "16-mixed-auto":
+                 return (
+                     cls(name="bfloat16")
+                     if is_bf16_supported_no_emulation()
+                     else cls(name="float16")
+                 )
+             case "fp16-mixed":
+                 return cls(name="float16")
+             case "bf16-mixed":
+                 return cls(name="bfloat16")
+             case "32-true":
+                 return cls(name="float32")
+             case "64-true":
+                 return cls(name="float64")
+             case _:
+                 assert_never(config.trainer.precision)
+
+     @property
+     def torch_dtype(self):
+         if ((dtype := getattr(torch, self.name, None)) is None) or not isinstance(
+             dtype, torch.dtype
+         ):
+             raise ValueError(f"Unknown dtype {self.name}")
+
+         return dtype
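
The new DTypeConfig is a small nshconfig model that names a torch dtype, can be derived from the trainer precision via from_base_config, and resolves to a real torch.dtype through the torch_dtype property. A minimal sketch of direct construction; from_base_config is omitted here because it needs a full BaseConfig, and the re-export path relies on the __init__.py change above:

    import torch
    from nshtrainer.util.config import DTypeConfig

    dtype_config = DTypeConfig(name="bfloat16")
    assert dtype_config.torch_dtype is torch.bfloat16

    # Names that do not resolve to a torch.dtype raise a ValueError when
    # torch_dtype is accessed.
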
nshtrainer-0.34.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: nshtrainer
- Version: 0.33.1
+ Version: 0.34.0
  Summary:
  Author: Nima Shoghi
  Author-email: nimashoghi@gmail.com
nshtrainer-0.34.0.dist-info/RECORD CHANGED
@@ -30,9 +30,9 @@ nshtrainer/callbacks/shared_parameters.py,sha256=fqlDweFDXPV_bfcAWpRgaJIad9i5Aeh
  nshtrainer/callbacks/throughput_monitor.py,sha256=H_ocXErZxUO3dxFk8Tx_VQdpI9E_Ztvqof5WtFevLyQ,1838
  nshtrainer/callbacks/timer.py,sha256=quS79oYClDUvQxJkNWmDMe0hwRUkkREgTgqzVrnom50,4607
  nshtrainer/callbacks/wandb_watch.py,sha256=Y6SEXfIx3kDDQbI5zpP53BVq0FBLJbLd3RJsiHZk1-Y,2921
- nshtrainer/config.py,sha256=9Hmgb-2BttQwFcp1wO5hyKyYzKZ_CYYI7RZbfcmxOzE,8762
+ nshtrainer/config.py,sha256=skar_Wfz50_sU2NZS8PEjqofWeon4g4cyIgby3Da81g,8308
  nshtrainer/data/__init__.py,sha256=7mk1tr7SWUZ7ySbsf0y0ZPszk7u4QznPhQ-7wnpH9ec,149
- nshtrainer/data/balanced_batch_sampler.py,sha256=dGBTDDtlBU6c-ZlVQOCnTW7SjTB5hczWsOWEdUWjvkA,4385
+ nshtrainer/data/balanced_batch_sampler.py,sha256=WAjhbO9EsZ_UadhdW3obBsjvEDMc2V-irpjegqIb7AI,4791
  nshtrainer/data/transform.py,sha256=6SNs3_TpNpfhcwTwvPKyEJ3opM1OT7LmMEYQNHKgRl8,2227
  nshtrainer/ll/__init__.py,sha256=L-aTi1V1bbvnZjOro8NvI393zbHQSFR9movWSRK9Mds,2477
  nshtrainer/ll/_experimental.py,sha256=oBQCKOEVYoxuUU9eLb-Fg2B2mzZD7SA0zfAO6lmWZ88,53
@@ -84,10 +84,12 @@ nshtrainer/trainer/_config.py,sha256=ZIodM5Ek1lpkWFhQ_VfmKR7q1mZFFwtjfx8FH72H8WM
  nshtrainer/trainer/_runtime_callback.py,sha256=sd2cUdRJG-UCdQr9ruZvEYpNGNF1t2W2fuxwwVlQD9E,4164
  nshtrainer/trainer/checkpoint_connector.py,sha256=r0ir4xYSdf_jebM0x09qaO6nJsvsiRQDyM0fs80ppOQ,2347
  nshtrainer/trainer/signal_connector.py,sha256=2EzkVktlasl8PgWAKNLDZRUMY__gRlDy1HdinAU-tfU,10740
- nshtrainer/trainer/trainer.py,sha256=iYueHW-m8fHyC8SQuXmpgxq_-GUa7pAJik7rDFPXmy0,17499
+ nshtrainer/trainer/trainer.py,sha256=8T4LB31ygXXS3DECkvD2uqgElAxkulacYvZyL_-imJs,16839
  nshtrainer/util/_environment_info.py,sha256=CFUUZYjXhBLWGc0jtPNOaZgYMueUDEHpEaWFA1f3GoY,24213
  nshtrainer/util/_useful_types.py,sha256=dwZokFkIe7M5i2GR3nQ9A1lhGw06DMAFfH5atyquqSA,8000
- nshtrainer/util/config/__init__.py,sha256=N2AOhaZC93DszvCdwvNL9KgnzJ2M3P-esFBY6VGih6Y,190
+ nshtrainer/util/bf16.py,sha256=VUnIG6aA4XtZscZc_dxv5ln_jlEbdU3eMFwDb5SEWSI,726
+ nshtrainer/util/config/__init__.py,sha256=o8fwPf_dctE_7CAkT0wNOBkvmxnzYzXeHpLedrZLt54,236
+ nshtrainer/util/config/dtype.py,sha256=JtYjrcBFNBlziJnLAE6QS0QV4PUXhGspYH1hNFrB3ks,1965
  nshtrainer/util/config/duration.py,sha256=pgIKQ88Dg8y1YAKUvUsNWu9hc9O79kdYBfgmC3a_-kQ,728
  nshtrainer/util/environment.py,sha256=AeW_kLl-N70wmb6L_JLz1wRj0kA70xs6RCmc9iUqczE,4159
  nshtrainer/util/path.py,sha256=VkpuhR4GaZtSFBVqbGAvfjcrU-PR8xwiGzzwFNOWP9c,2995
@@ -95,6 +97,6 @@ nshtrainer/util/seed.py,sha256=Or2wMPsnQxfnZ2xfBiyMcHFIUt3tGTNeMMyOEanCkqs,280
  nshtrainer/util/slurm.py,sha256=rofIU26z3SdL79SF45tNez6juou1cyDLz07oXEZb9Hg,1566
  nshtrainer/util/typed.py,sha256=NGuDkDzFlc1fAoaXjOFZVbmj0mRFjsQi1E_hPa7Bn5U,128
  nshtrainer/util/typing_utils.py,sha256=8ptjSSLZxlmy4FY6lzzkoGoF5fGNClo8-B_c0XHQaNU,385
- nshtrainer-0.33.1.dist-info/METADATA,sha256=TNb9UWbyEqq3Yt7Dp5NYttYog3AIGlo4SI-_HbC5s3Y,916
- nshtrainer-0.33.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- nshtrainer-0.33.1.dist-info/RECORD,,
+ nshtrainer-0.34.0.dist-info/METADATA,sha256=GYC9ejdKV3MCyOFhJcFjI-uedTWLGWj-SE5S79ruug4,916
+ nshtrainer-0.34.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ nshtrainer-0.34.0.dist-info/RECORD,,