nshtrainer 0.33.1__tar.gz → 0.34.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/PKG-INFO +1 -1
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/pyproject.toml +1 -1
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/config.py +35 -36
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/data/balanced_batch_sampler.py +30 -26
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/trainer/trainer.py +2 -25
- nshtrainer-0.34.0/src/nshtrainer/util/bf16.py +25 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/util/config/__init__.py +1 -0
- nshtrainer-0.34.0/src/nshtrainer/util/config/dtype.py +89 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/README.md +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/_callback.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/_checkpoint/loader.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/_checkpoint/metadata.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/_checkpoint/saver.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/_directory.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/_experimental/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/_hf_hub.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/actsave.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/base.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/checkpoint/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/checkpoint/_base.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/checkpoint/best_checkpoint.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/checkpoint/last_checkpoint.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/checkpoint/on_exception_checkpoint.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/debug_flag.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/directory_setup.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/early_stopping.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/ema.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/finite_checks.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/interval.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/log_epoch.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/norm_logging.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/print_table.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/rlp_sanity_checks.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/shared_parameters.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/timer.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/data/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/data/transform.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/_experimental.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/actsave.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/callbacks.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/config.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/data.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/log.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/lr_scheduler.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/model.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/nn.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/optimizer.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/runner.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/snapshot.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/snoop.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/trainer.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/typecheck.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/ll/util.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/loggers/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/loggers/_base.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/loggers/csv.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/loggers/tensorboard.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/loggers/wandb.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/lr_scheduler/_base.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/metrics/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/metrics/_config.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/model/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/model/base.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/model/config.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/model/mixins/logger.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/nn/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/nn/mlp.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/nn/module_dict.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/nn/module_list.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/nn/nonlinearity.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/optimizer.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/profiler/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/profiler/_base.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/profiler/advanced.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/profiler/pytorch.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/profiler/simple.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/runner.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/scripts/find_packages.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/trainer/__init__.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/trainer/_config.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/trainer/signal_connector.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/util/_environment_info.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/util/_useful_types.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/util/config/duration.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/util/environment.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/util/path.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/util/seed.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/util/slurm.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/util/typed.py +0 -0
- {nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/util/typing_utils.py +0 -0
{nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/config.py

@@ -1,9 +1,21 @@
 from nshconfig._config import Config as Config
 from nshsnap._config import SnapshotConfig as SnapshotConfig
 
+from nshtrainer._checkpoint.loader import (
+    BestCheckpointStrategyConfig as BestCheckpointStrategyConfig,
+)
 from nshtrainer._checkpoint.loader import (
     CheckpointLoadingConfig as CheckpointLoadingConfig,
 )
+from nshtrainer._checkpoint.loader import (
+    CheckpointLoadingStrategyConfig as CheckpointLoadingStrategyConfig,
+)
+from nshtrainer._checkpoint.loader import (
+    LastCheckpointStrategyConfig as LastCheckpointStrategyConfig,
+)
+from nshtrainer._checkpoint.loader import (
+    UserProvidedPathCheckpointStrategyConfig as UserProvidedPathCheckpointStrategyConfig,
+)
 from nshtrainer._checkpoint.metadata import CheckpointMetadata as CheckpointMetadata
 from nshtrainer._directory import DirectoryConfig as DirectoryConfig
 from nshtrainer._hf_hub import (
@@ -53,48 +65,13 @@ from nshtrainer.callbacks.throughput_monitor import (
 )
 from nshtrainer.callbacks.timer import EpochTimerConfig as EpochTimerConfig
 from nshtrainer.callbacks.wandb_watch import WandbWatchConfig as WandbWatchConfig
-from nshtrainer.config import
-from nshtrainer.config import UUID3 as UUID3
-from nshtrainer.config import UUID4 as UUID4
-from nshtrainer.config import UUID5 as UUID5
-from nshtrainer.config import AmqpDsn as AmqpDsn
-from nshtrainer.config import AnyHttpUrl as AnyHttpUrl
-from nshtrainer.config import AnyWebsocketUrl as AnyWebsocketUrl
-from nshtrainer.config import Base64Bytes as Base64Bytes
-from nshtrainer.config import Base64Str as Base64Str
-from nshtrainer.config import Base64UrlBytes as Base64UrlBytes
-from nshtrainer.config import Base64UrlStr as Base64UrlStr
-from nshtrainer.config import ClickHouseDsn as ClickHouseDsn
-from nshtrainer.config import CockroachDsn as CockroachDsn
-from nshtrainer.config import DirectoryPath as DirectoryPath
-from nshtrainer.config import FilePath as FilePath
-from nshtrainer.config import FileUrl as FileUrl
-from nshtrainer.config import FiniteFloat as FiniteFloat
-from nshtrainer.config import FtpUrl as FtpUrl
-from nshtrainer.config import HttpUrl as HttpUrl
-from nshtrainer.config import KafkaDsn as KafkaDsn
-from nshtrainer.config import MariaDBDsn as MariaDBDsn
-from nshtrainer.config import MongoDsn as MongoDsn
-from nshtrainer.config import MySQLDsn as MySQLDsn
-from nshtrainer.config import NatsDsn as NatsDsn
-from nshtrainer.config import NewPath as NewPath
-from nshtrainer.config import OnErrorOmit as OnErrorOmit
-from nshtrainer.config import PostgresDsn as PostgresDsn
-from nshtrainer.config import RedisDsn as RedisDsn
-from nshtrainer.config import SnowflakeDsn as SnowflakeDsn
-from nshtrainer.config import StrictBool as StrictBool
-from nshtrainer.config import StrictBytes as StrictBytes
-from nshtrainer.config import StrictFloat as StrictFloat
-from nshtrainer.config import StrictInt as StrictInt
-from nshtrainer.config import StrictStr as StrictStr
-from nshtrainer.config import WebsocketUrl as WebsocketUrl
+from nshtrainer.config import LRSchedulerConfig as LRSchedulerConfig
 from nshtrainer.loggers._base import BaseLoggerConfig as BaseLoggerConfig
 from nshtrainer.loggers.csv import CSVLoggerConfig as CSVLoggerConfig
 from nshtrainer.loggers.tensorboard import (
     TensorboardLoggerConfig as TensorboardLoggerConfig,
 )
 from nshtrainer.loggers.wandb import WandbLoggerConfig as WandbLoggerConfig
-from nshtrainer.lr_scheduler import LRSchedulerConfig as LRSchedulerConfig
 from nshtrainer.lr_scheduler._base import LRSchedulerConfigBase as LRSchedulerConfigBase
 from nshtrainer.lr_scheduler.linear_warmup_cosine import (
     DurationConfig as DurationConfig,
@@ -164,9 +141,31 @@ from nshtrainer.util._environment_info import (
     EnvironmentClassInformationConfig as EnvironmentClassInformationConfig,
 )
 from nshtrainer.util._environment_info import EnvironmentConfig as EnvironmentConfig
+from nshtrainer.util._environment_info import (
+    EnvironmentCUDAConfig as EnvironmentCUDAConfig,
+)
+from nshtrainer.util._environment_info import (
+    EnvironmentGPUConfig as EnvironmentGPUConfig,
+)
+from nshtrainer.util._environment_info import (
+    EnvironmentHardwareConfig as EnvironmentHardwareConfig,
+)
 from nshtrainer.util._environment_info import (
     EnvironmentLinuxEnvironmentConfig as EnvironmentLinuxEnvironmentConfig,
 )
+from nshtrainer.util._environment_info import (
+    EnvironmentLSFInformationConfig as EnvironmentLSFInformationConfig,
+)
+from nshtrainer.util._environment_info import (
+    EnvironmentPackageConfig as EnvironmentPackageConfig,
+)
 from nshtrainer.util._environment_info import (
     EnvironmentSLURMInformationConfig as EnvironmentSLURMInformationConfig,
 )
+from nshtrainer.util._environment_info import (
+    EnvironmentSnapshotConfig as EnvironmentSnapshotConfig,
+)
+from nshtrainer.util._environment_info import GitRepositoryConfig as GitRepositoryConfig
+from nshtrainer.util.config.dtype import DTypeConfig as DTypeConfig
+from nshtrainer.util.config.duration import EpochsConfig as EpochsConfig
+from nshtrainer.util.config.duration import StepsConfig as StepsConfig
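
The practical effect of the config.py hunks above: nshtrainer.config stops re-exporting the pydantic helper types (the UUID, DSN/URL, Base64, and Strict variants) and instead re-exports the checkpoint-strategy, environment-info, dtype, and duration configs. A rough sketch of how downstream imports might be adjusted for 0.34.0; the exact export list is whatever the released module contains, and the pydantic fallback is an assumption based on where those types originate:

    # Re-exported by nshtrainer.config in 0.34.0, per the diff above.
    from nshtrainer.config import (
        BestCheckpointStrategyConfig,
        DTypeConfig,
        EpochsConfig,
        LRSchedulerConfig,
    )

    # The pydantic helper types are no longer re-exported here; since they are
    # plain pydantic types, importing them from pydantic directly should work.
    from pydantic import HttpUrl, StrictInt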

{nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/data/balanced_batch_sampler.py

@@ -1,13 +1,12 @@
 import heapq
 import logging
-from functools import cached_property
 from typing import Any, Protocol, runtime_checkable
 
 import numpy as np
 import torch
 import torch.distributed
 from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
-from torch.utils.data import BatchSampler,
+from torch.utils.data import BatchSampler, DistributedSampler
 from typing_extensions import override
 
 log = logging.getLogger(__name__)
@@ -47,24 +46,16 @@ class DatasetWithSizes(Protocol):
     def data_sizes(self, indices: list[int]) -> np.ndarray: ...
 
 
-
-
-def
-    if not isinstance(dataset, Dataset):
-        raise ValueError(
-            "BalancedBatchSampler requires a dataset that implements `__getitem__`"
-        )
-
-    if not isinstance(dataset, DatasetWithSizes):
-        raise ValueError(
-            "BalancedBatchSampler requires a dataset that implements `data_sizes`"
-        )
+@runtime_checkable
+class DataSizesFunction(Protocol):
+    def __call__(self, dataset: Any, indices: list[int]) -> np.ndarray: ...
 
-    log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
-    return dataset
 
+class BalancedBatchSampler(BatchSampler):
     @staticmethod
-    def _unwrap_dataset(dataset:
+    def _unwrap_dataset(dataset: Any):
+        # Lightning's DistributedSampler wraps the dataset in a _DatasetSamplerWrapper,
+        # so we need to unwrap it to get the actual dataset.
         if isinstance(dataset, _DatasetSamplerWrapper):
             if (data_source := getattr(dataset._sampler, "data_source", None)) is None:
                 raise ValueError("Could not unwrap dataset from _DatasetSamplerWrapper")
@@ -79,12 +70,6 @@ class BalancedBatchSampler(BatchSampler):
         )
         return self.sampler
 
-    @cached_property
-    def dataset(self):
-        return self._ensure_supported(
-            self._unwrap_dataset(self.distributed_sampler.dataset)
-        )
-
     def __init__(
         self,
         sampler: DistributedSampler,
@@ -92,10 +77,12 @@ class BalancedBatchSampler(BatchSampler):
         batch_size: int,
         device: torch.device,
         drop_last: bool = False,
+        data_sizes_fn: DataSizesFunction | None = None,
     ):
         super().__init__(sampler, batch_size, drop_last=drop_last)
 
         self._device = device
+        self._data_sizes_fn = data_sizes_fn
 
         log.info(
             f"Created BalancedBatchSampler with {sampler=}, {batch_size=}, {drop_last=}"
@@ -105,17 +92,34 @@ class BalancedBatchSampler(BatchSampler):
     def _dist_enabled():
         return torch.distributed.is_available() and torch.distributed.is_initialized()
 
+    def _dataset_sizes(self, indices: list[int]) -> np.ndarray:
+        dataset = self._unwrap_dataset(self.distributed_sampler.dataset)
+        # Dataset much either implement `data_sizes`, or we need to provide a custom
+        # implementation of the dataset sizes function.
+        if isinstance(dataset, DatasetWithSizes):
+            log.critical(f"BalancedBatchSampler: Resolved dataset to {type(dataset)}")
+            return dataset.data_sizes(indices)
+
+        if (data_sizes_fn := self._data_sizes_fn) is not None:
+            return data_sizes_fn(dataset, indices)
+
+        raise ValueError(
+            "Dataset must implement the `data_sizes` method, "
+            "or a custom data_sizes_fn must be provided "
+            "to the BalancedBatchSampler."
+        )
+
     @override
     def __iter__(self):
         if not self._dist_enabled():
             yield from super().__iter__()
             return
 
-        for
-            sizes = self.
+        for batch_idxs in super().__iter__():
+            sizes = self._dataset_sizes(batch_idxs)
             idx_sizes = torch.stack(
                 [
-                    torch.tensor(
+                    torch.tensor(batch_idxs, device=self._device),
                     torch.tensor(sizes, device=self._device),
                 ]
             )
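
For users of BalancedBatchSampler, the change above means the dataset itself no longer has to implement data_sizes: a data_sizes_fn callable matching the new DataSizesFunction protocol can be passed to the constructor, and it is used whenever the unwrapped dataset does not satisfy DatasetWithSizes. A minimal sketch, assuming a hypothetical my_dataset whose per-sample cost is its length; everything outside the BalancedBatchSampler call is illustrative:

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, DistributedSampler

    from nshtrainer.data.balanced_batch_sampler import BalancedBatchSampler

    def my_data_sizes(dataset, indices: list[int]) -> np.ndarray:
        # One "size" per requested index, e.g. node or token counts per sample.
        return np.array([len(dataset[i]) for i in indices])

    sampler = DistributedSampler(my_dataset)
    batch_sampler = BalancedBatchSampler(
        sampler,
        batch_size=8,
        device=torch.device("cuda"),
        data_sizes_fn=my_data_sizes,  # fallback when the dataset has no `data_sizes`
    )
    loader = DataLoader(my_dataset, batch_sampler=batch_sampler)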

{nshtrainer-0.33.1 → nshtrainer-0.34.0}/src/nshtrainer/trainer/trainer.py

@@ -18,6 +18,7 @@ from typing_extensions import Unpack, assert_never, override
 
 from .._checkpoint.metadata import _write_checkpoint_metadata
 from ..callbacks.base import resolve_all_callbacks
+from ..util.bf16 import is_bf16_supported_no_emulation
 from ._config import (
     AcceleratorConfigProtocol,
     LightningTrainerKwargs,
@@ -33,30 +34,6 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)
 
 
-def _is_bf16_supported_no_emulation():
-    r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
-    version = getattr(torch, "version")
-
-    # Check for ROCm, if true return true, no ROCM_VERSION check required,
-    # since it is supported on AMD GPU archs.
-    if version.hip:
-        return True
-
-    device = torch.cuda.current_device()
-
-    # Check for CUDA version and device compute capability.
-    # This is a fast way to check for it.
-    cuda_version = version.cuda
-    if (
-        cuda_version is not None
-        and int(cuda_version.split(".")[0]) >= 11
-        and torch.cuda.get_device_properties(device).major >= 8
-    ):
-        return True
-
-    return False
-
-
 class Trainer(LightningTrainer):
     @classmethod
     def _pre_init(cls, config: "BaseConfig"):
@@ -188,7 +165,7 @@ class Trainer(LightningTrainer):
         try:
             resolved_precision = (
                 "bf16-mixed"
-                if
+                if is_bf16_supported_no_emulation()
                 else "16-mixed"
             )
         except BaseException:
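
The trainer change is essentially a refactor: the private _is_bf16_supported_no_emulation helper is deleted here, and the shared is_bf16_supported_no_emulation from nshtrainer.util.bf16 (added below) is called instead when resolving a mixed-precision request. The resolution it feeds is roughly the following sketch; the surrounding try/except lives in nshtrainer, and the fallback value taken on failure is an assumption, since the except body is not shown in this diff:

    from nshtrainer.util.bf16 import is_bf16_supported_no_emulation

    try:
        # Prefer bf16 mixed precision when the device supports it natively.
        resolved_precision = (
            "bf16-mixed" if is_bf16_supported_no_emulation() else "16-mixed"
        )
    except BaseException:
        # Assumed fallback when detection itself fails (e.g. no CUDA device).
        resolved_precision = "16-mixed"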

nshtrainer-0.34.0/src/nshtrainer/util/bf16.py (new file)

@@ -0,0 +1,25 @@
+import torch
+
+
+def is_bf16_supported_no_emulation():
+    r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
+    version = getattr(torch, "version")
+
+    # Check for ROCm, if true return true, no ROCM_VERSION check required,
+    # since it is supported on AMD GPU archs.
+    if version.hip:
+        return True
+
+    device = torch.cuda.current_device()
+
+    # Check for CUDA version and device compute capability.
+    # This is a fast way to check for it.
+    cuda_version = version.cuda
+    if (
+        cuda_version is not None
+        and int(cuda_version.split(".")[0]) >= 11
+        and torch.cuda.get_device_properties(device).major >= 8
+    ):
+        return True
+
+    return False
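
Judging by its name and body, the new helper reports only native bfloat16 support (no emulation): it returns True on ROCm, or on CUDA 11+ with a device of compute capability 8.0 or higher, and False otherwise. A small usage sketch; the torch.cuda.is_available() guard is an assumption added here because the helper calls torch.cuda.current_device():

    import torch

    from nshtrainer.util.bf16 import is_bf16_supported_no_emulation

    if torch.cuda.is_available() and is_bf16_supported_no_emulation():
        precision = "bf16-mixed"  # native bf16 (Ampere or newer, or ROCm)
    else:
        precision = "16-mixed"  # fall back to fp16 mixed precision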

nshtrainer-0.34.0/src/nshtrainer/util/config/dtype.py (new file)

@@ -0,0 +1,89 @@
+from typing import TYPE_CHECKING, Literal, TypeAlias
+
+import nshconfig as C
+import torch
+from typing_extensions import assert_never
+
+from ..bf16 import is_bf16_supported_no_emulation
+
+if TYPE_CHECKING:
+    from ...model.base import BaseConfig
+
+DTypeName: TypeAlias = Literal[
+    "float32",
+    "float",
+    "float64",
+    "double",
+    "float16",
+    "bfloat16",
+    "float8_e4m3fn",
+    "float8_e4m3fnuz",
+    "float8_e5m2",
+    "float8_e5m2fnuz",
+    "half",
+    "uint8",
+    "uint16",
+    "uint32",
+    "uint64",
+    "int8",
+    "int16",
+    "short",
+    "int32",
+    "int",
+    "int64",
+    "long",
+    "complex32",
+    "complex64",
+    "chalf",
+    "cfloat",
+    "complex128",
+    "cdouble",
+    "quint8",
+    "qint8",
+    "qint32",
+    "bool",
+    "quint4x2",
+    "quint2x4",
+    "bits1x8",
+    "bits2x4",
+    "bits4x2",
+    "bits8",
+    "bits16",
+]
+
+
+class DTypeConfig(C.Config):
+    name: DTypeName
+    """The name of the dtype."""
+
+    @classmethod
+    def from_base_config(cls, config: "BaseConfig"):
+        if (precision := config.trainer.precision) is None:
+            precision = "32-true"
+
+        match precision:
+            case "16-mixed-auto":
+                return (
+                    cls(name="bfloat16")
+                    if is_bf16_supported_no_emulation()
+                    else cls(name="float16")
+                )
+            case "fp16-mixed":
+                return cls(name="float16")
+            case "bf16-mixed":
+                return cls(name="bfloat16")
+            case "32-true":
+                return cls(name="float32")
+            case "64-true":
+                return cls(name="float64")
+            case _:
+                assert_never(config.trainer.precision)
+
+    @property
+    def torch_dtype(self):
+        if ((dtype := getattr(torch, self.name, None)) is None) or not isinstance(
+            dtype, torch.dtype
+        ):
+            raise ValueError(f"Unknown dtype {self.name}")
+
+        return dtype
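
DTypeConfig gives configs a serializable handle on a torch dtype: name is validated against the DTypeName literal, torch_dtype resolves it with getattr(torch, name), and from_base_config derives it from the trainer precision, with "16-mixed-auto" resolving to bfloat16 only when it is natively supported. A short sketch of how it can be used; the cfg mentioned in the comment is a hypothetical nshtrainer BaseConfig instance:

    import torch

    from nshtrainer.util.config.dtype import DTypeConfig

    dtype_cfg = DTypeConfig(name="bfloat16")
    assert dtype_cfg.torch_dtype is torch.bfloat16

    # DTypeConfig.from_base_config(cfg) maps "bf16-mixed" -> bfloat16,
    # "32-true" -> float32, "64-true" -> float64, and so on.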