nshtrainer 0.33.2__tar.gz → 0.34.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/PKG-INFO +1 -1
  2. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/pyproject.toml +1 -1
  3. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/config.py +34 -3
  4. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/data/balanced_batch_sampler.py +30 -26
  5. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/trainer/trainer.py +2 -25
  6. nshtrainer-0.34.1/src/nshtrainer/util/bf16.py +25 -0
  7. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/util/config/__init__.py +1 -0
  8. nshtrainer-0.34.1/src/nshtrainer/util/config/dtype.py +89 -0
  9. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/README.md +0 -0
  10. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/__init__.py +0 -0
  11. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/_callback.py +0 -0
  12. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/_checkpoint/loader.py +0 -0
  13. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  14. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/_checkpoint/saver.py +0 -0
  15. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/_directory.py +0 -0
  16. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/_experimental/__init__.py +0 -0
  17. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/_hf_hub.py +0 -0
  18. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/__init__.py +0 -0
  19. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  20. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/actsave.py +0 -0
  21. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/base.py +0 -0
  22. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
  23. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
  24. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
  25. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
  26. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
  27. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/debug_flag.py +0 -0
  28. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/directory_setup.py +0 -0
  29. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  30. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/ema.py +0 -0
  31. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  32. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  33. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/interval.py +0 -0
  34. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  35. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  36. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/print_table.py +0 -0
  37. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
  38. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
  39. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  40. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/timer.py +0 -0
  41. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  42. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/data/__init__.py +0 -0
  43. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/data/transform.py +0 -0
  44. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/__init__.py +0 -0
  45. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/_experimental.py +0 -0
  46. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/actsave.py +0 -0
  47. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/callbacks.py +0 -0
  48. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/config.py +0 -0
  49. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/data.py +0 -0
  50. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/log.py +0 -0
  51. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  52. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/model.py +0 -0
  53. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/nn.py +0 -0
  54. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/optimizer.py +0 -0
  55. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/runner.py +0 -0
  56. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/snapshot.py +0 -0
  57. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/snoop.py +0 -0
  58. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/trainer.py +0 -0
  59. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/typecheck.py +0 -0
  60. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/ll/util.py +0 -0
  61. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/loggers/__init__.py +0 -0
  62. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/loggers/_base.py +0 -0
  63. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/loggers/csv.py +0 -0
  64. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/loggers/tensorboard.py +0 -0
  65. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/loggers/wandb.py +0 -0
  66. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  67. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  68. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  69. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  70. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/metrics/__init__.py +0 -0
  71. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/metrics/_config.py +0 -0
  72. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/model/__init__.py +0 -0
  73. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/model/base.py +0 -0
  74. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/model/config.py +0 -0
  75. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/model/mixins/logger.py +0 -0
  76. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/nn/__init__.py +0 -0
  77. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/nn/mlp.py +0 -0
  78. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/nn/module_dict.py +0 -0
  79. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/nn/module_list.py +0 -0
  80. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/nn/nonlinearity.py +0 -0
  81. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/optimizer.py +0 -0
  82. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/profiler/__init__.py +0 -0
  83. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/profiler/_base.py +0 -0
  84. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/profiler/advanced.py +0 -0
  85. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/profiler/pytorch.py +0 -0
  86. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/profiler/simple.py +0 -0
  87. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/runner.py +0 -0
  88. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/scripts/find_packages.py +0 -0
  89. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/trainer/__init__.py +0 -0
  90. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/trainer/_config.py +0 -0
  91. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  92. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  93. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/trainer/signal_connector.py +0 -0
  94. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/util/_environment_info.py +0 -0
  95. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/util/_useful_types.py +0 -0
  96. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/util/config/duration.py +0 -0
  97. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/util/environment.py +0 -0
  98. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/util/path.py +0 -0
  99. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/util/seed.py +0 -0
  100. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/util/slurm.py +0 -0
  101. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/util/typed.py +0 -0
  102. {nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.33.2 → nshtrainer-0.34.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nshtrainer
-Version: 0.33.2
+Version: 0.34.1
 Summary:
 Author: Nima Shoghi
 Author-email: nimashoghi@gmail.com
{nshtrainer-0.33.2 → nshtrainer-0.34.1}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "nshtrainer"
-version = "0.33.2"
+version = "0.34.1"
 description = ""
 authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
 readme = "README.md"
{nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/config.py
@@ -1,9 +1,18 @@
-from nshconfig._config import Config as Config
-from nshsnap._config import SnapshotConfig as SnapshotConfig
-
+from nshtrainer._checkpoint.loader import (
+    BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
+)
 from nshtrainer._checkpoint.loader import (
     CheckpointLoadingConfig as CheckpointLoadingConfig,
 )
+from nshtrainer._checkpoint.loader import (
+    CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
+)
+from nshtrainer._checkpoint.loader import (
+    LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
+)
+from nshtrainer._checkpoint.loader import (
+    UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
+)
 from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
 from nshtrainer._directory import DirectoryConfig as DirectoryConfig
 from nshtrainer._hf_hub import (
@@ -129,9 +138,31 @@ from nshtrainer.util._environment_info import (
     EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
 )
 from nshtrainer.util._environment_info import EnvironmentConfig as EnvironmentConfig
+from nshtrainer.util._environment_info import (
+    EnvironmentCUDAConfig as EnvironmentCUDAConfig,
+)
+from nshtrainer.util._environment_info import (
+    EnvironmentGPUConfig as EnvironmentGPUConfig,
+)
+from nshtrainer.util._environment_info import (
+    EnvironmentHardwareConfig as EnvironmentHardwareConfig,
+)
 from nshtrainer.util._environment_info import (
     EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
 )
+from nshtrainer.util._environment_info import (
+    EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
+)
+from nshtrainer.util._environment_info import (
+    EnvironmentPackageConfig as EnvironmentPackageConfig,
+)
 from nshtrainer.util._environment_info import (
     EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
 )
+from nshtrainer.util._environment_info import (
+    EnvironmentSnapshotConfig as EnvironmentSnapshotConfig,
+)
+from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
+from nshtrainer.util.config.dtype import DTypeConfig as DTypeConfig
+from nshtrainer.util.config.duration import EpochsConfig as EpochsConfig
+from nshtrainer.util.config.duration import StepsConfig as StepsConfig
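Taken together, these two hunks re-export the checkpoint-loading strategies, the additional environment-info configs, and the new dtype/duration configs from the flat nshtrainer.config namespace. A short, illustrative sketch of that import surface after the change (the usage itself is not taken from the package):

# Illustrative only: pull the newly re-exported configs from nshtrainer.config
# rather than from the private modules they are defined in.
from nshtrainer.config import (
    BestCheckpointStrategyConfig,
    DTypeConfig,
    EpochsConfig,
    StepsConfig,
)

dtype_cfg = DTypeConfig(name="bfloat16")  # validated against the DTypeName literal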
{nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/data/balanced_batch_sampler.py
@@ -1,13 +1,12 @@
 import heapq
 import logging
-from functools import cached_property
 from typing import Any, Protocol, runtime_checkable

 import numpy as np
 import torch
 import torch.distributed
 from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
-from torch.utils.data import BatchSampler, Dataset, DistributedSampler
+from torch.utils.data import BatchSampler, DistributedSampler
 from typing_extensions import override

 log = logging.getLogger(__name__)
@@ -47,24 +46,16 @@ class DatasetWithSizes(Protocol):
     def data_sizes(self, indices: list[int]) -> np.ndarray: ...


-class BalancedBatchSampler(BatchSampler):
-    @staticmethod
-    def _ensure_supported(dataset: Any):
-        if not isinstance(dataset, Dataset):
-            raise ValueError(
-                "BalancedBatchSampler requires a dataset that implements `__getitem__`"
-            )
-
-        if not isinstance(dataset, DatasetWithSizes):
-            raise ValueError(
-                "BalancedBatchSampler requires a dataset that implements `data_sizes`"
-            )
+@runtime_checkable
+class DataSizesFunction(Protocol):
+    def __call__(self, dataset: Any, indices: list[int]) -> np.ndarray: ...

-        log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
-        return dataset

+class BalancedBatchSampler(BatchSampler):
     @staticmethod
-    def _unwrap_dataset(dataset: Dataset) -> Dataset:
+    def _unwrap_dataset(dataset: Any):
+        # Lightning's DistributedSampler wraps the dataset in a _DatasetSamplerWrapper,
+        # so we need to unwrap it to get the actual dataset.
         if isinstance(dataset, _DatasetSamplerWrapper):
             if (data_source := getattr(dataset._sampler, "data_source", None)) is None:
                 raise ValueError("Could not unwrap dataset from _DatasetSamplerWrapper")
@@ -79,12 +70,6 @@ class BalancedBatchSampler(BatchSampler):
         )
         return self.sampler

-    @cached_property
-    def dataset(self):
-        return self._ensure_supported(
-            self._unwrap_dataset(self.distributed_sampler.dataset)
-        )
-
     def __init__(
         self,
         sampler: DistributedSampler,
@@ -92,10 +77,12 @@
         batch_size: int,
         device: torch.device,
         drop_last: bool = False,
+        data_sizes_fn: DataSizesFunction | None = None,
     ):
         super().__init__(sampler, batch_size, drop_last=drop_last)

         self._device = device
+        self._data_sizes_fn = data_sizes_fn

         log.info(
             f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}"
@@ -105,17 +92,34 @@
     def _dist_enabled():
         return torch.distributed.is_available() and torch.distributed.is_initialized()

+    def _dataset_sizes(self, indices: list[int]) -> np.ndarray:
+        dataset = self._unwrap_dataset(self.distributed_sampler.dataset)
+        # Dataset much either implement `data_sizes`, or we need to provide a custom
+        # implementation of the dataset sizes function.
+        if isinstance(dataset, DatasetWithSizes):
+            log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
+            return dataset.data_sizes(indices)
+
+        if (data_sizes_fn := self._data_sizes_fn) is not None:
+            return data_sizes_fn(dataset, indices)
+
+        raise ValueError(
+            "Dataset must implement the `data_sizes` method, "
+            "or a custom data_sizes_fn must be provided "
+            "to the BalancedBatchSampler."
+        )
+
     @override
     def __iter__(self):
         if not self._dist_enabled():
             yield from super().__iter__()
             return

-        for batch_idx in super().__iter__():
-            sizes = self.dataset.data_sizes(batch_idx)
+        for batch_idxs in super().__iter__():
+            sizes = self._dataset_sizes(batch_idxs)
             idx_sizes = torch.stack(
                 [
-                    torch.tensor(batch_idx, device=self._device),
+                    torch.tensor(batch_idxs, device=self._device),
                     torch.tensor(sizes, device=self._device),
                 ]
             )
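The practical effect of the new data_sizes_fn hook is that the dataset no longer has to implement data_sizes itself; any callable matching the DataSizesFunction protocol can supply per-sample sizes. A minimal sketch of that usage, assuming a hypothetical toy dataset with a precomputed lengths attribute (the dataset and sizing function below are not part of the package):

# Hypothetical example of the data_sizes_fn hook added in 0.34.1.
import numpy as np
import torch
from torch.utils.data import Dataset, DistributedSampler

from nshtrainer.data.balanced_batch_sampler import BalancedBatchSampler


class ToyDataset(Dataset):
    """Hypothetical dataset that does NOT implement `data_sizes`."""

    def __init__(self, lengths: list[int]):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx: int):
        return torch.zeros(self.lengths[idx])


def my_data_sizes_fn(dataset: ToyDataset, indices: list[int]) -> np.ndarray:
    # Look the per-sample "size" (e.g. token or node count) up from a precomputed array.
    return np.asarray([dataset.lengths[i] for i in indices])


dataset = ToyDataset([3, 5, 2, 7])
# num_replicas/rank given explicitly so no process group is required for the sketch.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0)
batch_sampler = BalancedBatchSampler(
    sampler,
    batch_size=2,
    device=torch.device("cpu"),
    data_sizes_fn=my_data_sizes_fn,  # consulted when the dataset lacks `data_sizes`
)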
{nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/trainer/trainer.py
@@ -18,6 +18,7 @@ from typing_extensions import Unpack, assert_never, override

 from .._checkpoint.metadata import _write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
+from ..util.bf16 import is_bf16_supported_no_emulation
 from ._config import (
     AcceleratorConfigProtocol,
     LightningTrainerKwargs,
@@ -33,30 +34,6 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)


-def _is_bf16_supported_no_emulation():
-    r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
-    version = getattr(torch, "version")
-
-    # Check for ROCm, if true return true, no ROCM_VERSION check required,
-    # since it is supported on AMD GPU archs.
-    if version.hip:
-        return True
-
-    device = torch.cuda.current_device()
-
-    # Check for CUDA version and device compute capability.
-    # This is a fast way to check for it.
-    cuda_version = version.cuda
-    if (
-        cuda_version is not None
-        and int(cuda_version.split(".")[0]) >= 11
-        and torch.cuda.get_device_properties(device).major >= 8
-    ):
-        return True
-
-    return False
-
-
 class Trainer(LightningTrainer):
     @classmethod
     def _pre_init(cls, config: "BaseConfig"):
@@ -188,7 +165,7 @@ class Trainer(LightningTrainer):
         try:
             resolved_precision = (
                 "bf16-mixed"
-                if _is_bf16_supported_no_emulation()
+                if is_bf16_supported_no_emulation()
                 else "16-mixed"
             )
         except BaseException:
nshtrainer-0.34.1/src/nshtrainer/util/bf16.py (new file)
@@ -0,0 +1,25 @@
+import torch
+
+
+def is_bf16_supported_no_emulation():
+    r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
+    version = getattr(torch, "version")
+
+    # Check for ROCm, if true return true, no ROCM_VERSION check required,
+    # since it is supported on AMD GPU archs.
+    if version.hip:
+        return True
+
+    device = torch.cuda.current_device()
+
+    # Check for CUDA version and device compute capability.
+    # This is a fast way to check for it.
+    cuda_version = version.cuda
+    if (
+        cuda_version is not None
+        and int(cuda_version.split(".")[0]) >= 11
+        and torch.cuda.get_device_properties(device).major >= 8
+    ):
+        return True
+
+    return False
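This is the helper that moved out of trainer/trainer.py; the Trainer now imports it to choose between "bf16-mixed" and "16-mixed". A small sketch of the same decision outside the Trainer (the fallback in the except branch is an assumption here; the Trainer's own fallback body is not shown in this diff):

# Sketch: pick a Lightning precision string with the relocated helper.
from nshtrainer.util.bf16 import is_bf16_supported_no_emulation

try:
    # Requires an initialized CUDA/ROCm device; raises on CPU-only setups.
    precision = "bf16-mixed" if is_bf16_supported_no_emulation() else "16-mixed"
except BaseException:
    precision = "16-mixed"  # assumed fallback, not taken from the package

print(precision)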
{nshtrainer-0.33.2 → nshtrainer-0.34.1}/src/nshtrainer/util/config/__init__.py
@@ -1,4 +1,5 @@
 from . import duration as duration
+from .dtype import DTypeConfig as DTypeConfig
 from .duration import DurationConfig as DurationConfig
 from .duration import EpochsConfig as EpochsConfig
 from .duration import StepsConfig as StepsConfig
nshtrainer-0.34.1/src/nshtrainer/util/config/dtype.py (new file)
@@ -0,0 +1,89 @@
+from typing import TYPE_CHECKING, Literal, TypeAlias
+
+import nshconfig as C
+import torch
+from typing_extensions import assert_never
+
+from ..bf16 import is_bf16_supported_no_emulation
+
+if TYPE_CHECKING:
+    from ...model.base import BaseConfig
+
+DTypeName: TypeAlias = Literal[
+    "float32",
+    "float",
+    "float64",
+    "double",
+    "float16",
+    "bfloat16",
+    "float8_e4m3fn",
+    "float8_e4m3fnuz",
+    "float8_e5m2",
+    "float8_e5m2fnuz",
+    "half",
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+    "int8",
+    "int16",
+    "short",
+    "int32",
+    "int",
+    "int64",
+    "long",
+    "complex32",
+    "complex64",
+    "chalf",
+    "cfloat",
+    "complex128",
+    "cdouble",
+    "quint8",
+    "qint8",
+    "qint32",
+    "bool",
+    "quint4x2",
+    "quint2x4",
+    "bits1x8",
+    "bits2x4",
+    "bits4x2",
+    "bits8",
+    "bits16",
+]
+
+
+class DTypeConfig(C.Config):
+    name: DTypeName
+    """The name of the dtype."""
+
+    @classmethod
+    def from_base_config(cls, config: "BaseConfig"):
+        if (precision := config.trainer.precision) is None:
+            precision = "32-true"
+
+        match precision:
+            case "16-mixed-auto":
+                return (
+                    cls(name="bfloat16")
+                    if is_bf16_supported_no_emulation()
+                    else cls(name="float16")
+                )
+            case "fp16-mixed":
+                return cls(name="float16")
+            case "bf16-mixed":
+                return cls(name="bfloat16")
+            case "32-true":
+                return cls(name="float32")
+            case "64-true":
+                return cls(name="float64")
+            case _:
+                assert_never(config.trainer.precision)
+
+    @property
+    def torch_dtype(self):
+        if ((dtype := getattr(torch, self.name, None)) is None) or not isinstance(
+            dtype, torch.dtype
+        ):
+            raise ValueError(f"Unknown dtype {self.name}")
+
+        return dtype
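DTypeConfig is a thin nshconfig wrapper around a dtype name: torch_dtype resolves the name via getattr(torch, name), and from_base_config maps the trainer's precision string to a dtype. A brief usage sketch (the base_config object referenced in the comment is hypothetical):

# Sketch of the two entry points defined above.
import torch

from nshtrainer.util.config.dtype import DTypeConfig

cfg = DTypeConfig(name="bfloat16")
assert cfg.torch_dtype is torch.bfloat16  # resolved via getattr(torch, "bfloat16")

# Hypothetical: derive the dtype from a BaseConfig's trainer.precision field.
# cfg = DTypeConfig.from_base_config(base_config)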