fkat 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/__init__.py
ADDED
@@ -0,0 +1,147 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
import sys
import asyncio
from collections.abc import Callable

import hydra
import omegaconf as oc
import lightning as L
import torch
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import record

from fkat.utils import config, pdb
from fkat.utils.logging import rank0_logger
from fkat.utils.config import SingletonResolver

log = rank0_logger(__name__)


def run_main(main: Callable[[], None]) -> None:
    patch_args()

    @record
    async def async_main() -> None:
        try:
            main()
        except Exception as e:
            import traceback

            traceback.print_tb(e.__traceback__)
            raise e

    asyncio.run(async_main())  # type: ignore[arg-type]


def patch_args() -> None:
    """
    In case we need to pass wildcard arguments (e.g. overrides) as expected by Hydra,
    but the runtime only allows named arguments, we pass them using a bogus "--overrides" flag.
    This function takes care of removing this flag before Hydra is called.
    """
    overrides_pos = -1
    for i, a in enumerate(sys.argv):
        if a == "--overrides":
            overrides_pos = i
            break
    if overrides_pos >= 0:
        overrides = sys.argv[overrides_pos + 1] if overrides_pos + 1 < len(sys.argv) else ""
        # skip the overrides when constructing the new args; there could be more args ahead
        sys.argv = sys.argv[:overrides_pos] + (
            sys.argv[overrides_pos + 2 :] if overrides_pos + 2 < len(sys.argv) else []
        )
        if overrides:
            # add the overrides at the end
            sys.argv.extend(overrides.split(" "))


def setup(
    cfg: oc.DictConfig | None = None,
    print_config: bool = False,
    multiprocessing: str = "spawn",
    seed: int | None = None,
    post_mortem: bool = False,
    determinism: bool = False,
    resolvers: dict[str, "oc.Resolver"] | None = None,
) -> SingletonResolver:
    """Set up the training environment.

    Args:
        cfg (oc.DictConfig | None): Full configuration
        print_config (bool): Whether to print the configuration to output. Defaults to ``False``
        multiprocessing (str): Multiprocessing mode. Defaults to ``spawn``
        seed (int | None): Random number generator seed to start from when set
        post_mortem (bool): Whether to start the pdb debugger when an uncaught exception is encountered.
            Defaults to ``False``
        determinism (bool): Whether to enforce deterministic algorithms. Defaults to ``False``
        resolvers (dict[str, oc.Resolver] | None): Custom resolvers to register for configuration processing

    Returns:
        :class:`SingletonResolver` object that holds initialized data, trainer, model, etc.
    """
    if print_config:
        log.info(config.to_str(cfg))

    mp.set_start_method(multiprocessing, force=True)

    if seed:
        L.seed_everything(seed)

    if post_mortem:
        pdb.post_mortem()

    if determinism:  # Enable deterministic algorithms globally
        assert seed is not None, "seed has to be set for deterministic runs"
        torch.use_deterministic_algorithms(True)
        if torch.cuda.is_available():
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        deterministic_env_vars = {
            "CUBLAS_WORKSPACE_CONFIG": [":16:8", ":4096:8"],
            "NCCL_ALGO": ["^NVLS"],
            "NVTE_ALLOW_NONDETERMINISTIC_ALGO": ["0"],
        }
        for var, vals in deterministic_env_vars.items():
            if (val := os.environ.get(var, vals[-1])) not in vals:
                raise ValueError(f"{var} has to be set to one of {vals} for deterministic runs, got: {val}")
            os.environ[var] = val
    for rn, fn in (resolvers or {}).items():
        oc.OmegaConf.register_new_resolver(rn, fn, replace=True)

    s = config.register_singleton_resolver()
    return s


def initialize(cfg: oc.DictConfig) -> SingletonResolver:
    """Initialize data, model and trainer with the supplied configuration.

    Args:
        cfg (oc.DictConfig): Configuration supplied by the user through a yaml file.

    Returns:
        :class:`SingletonResolver` object that holds initialized data, trainer, model, etc.
    """
    # 0. set up the training environment
    s = setup(cfg, **(hydra.utils.instantiate(cfg["setup"]) if "setup" in cfg else {}))

    # 1. instantiate `trainer`
    s.trainer = hydra.utils.instantiate(cfg.trainer)

    # 2. instantiate optional `data`
    s.data = hydra.utils.instantiate(cfg.get("data"))

    # 3. instantiate `model` after `trainer`
    s.model = hydra.utils.instantiate(cfg.model)

    # 4. obtain optional `ckpt_path`, `return_predictions` after `model`
    s.ckpt_path = hydra.utils.call(cfg.get("ckpt_path"))
    s.return_predictions = hydra.utils.call(cfg.get("return_predictions"))

    # 5. save and upload the config
    config.save(cfg, s.trainer)

    # 6. run tuners
    s.tuners = hydra.utils.instantiate(cfg.get("tuners"))

    return s
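For orientation, here is a minimal, hypothetical entrypoint sketch showing how `run_main` and `initialize` could be wired to a Hydra config. The `conf` directory, `config.yaml` name, and the `fit` call are illustrative assumptions and are not taken from the package's own `fkat/train.py`.

# Hypothetical usage sketch (not part of the wheel). Assumes a Hydra config
# directory "conf" with a config.yaml that defines at least `trainer` and `model`.
import hydra
import omegaconf as oc

import fkat


@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: oc.DictConfig) -> None:
    s = fkat.initialize(cfg)  # builds trainer, data, model, ckpt_path from the config
    s.trainer.fit(s.model, datamodule=s.data, ckpt_path=s.ckpt_path)


if __name__ == "__main__":
    fkat.run_main(main)  # strips the bogus "--overrides" flag before Hydra parses argv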
fkat/data/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .data_module import DataModule, PersistStates, RestoreStates
from .shm import ShmDataLoader
from .sharded import ShardedDataLoader
from .dict import DictDataLoader

__all__ = [
    "DataModule",
    "ShmDataLoader",
    "ShardedDataLoader",
    "DictDataLoader",
    "PersistStates",
    "RestoreStates",
]
fkat/data/data_module.py
ADDED
@@ -0,0 +1,198 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from functools import partial
from typing import Any
from collections.abc import Callable, Iterable

import lightning as L
from lightning.pytorch.profilers import Profiler
from lightning.pytorch.core.hooks import CheckpointHooks
from torch.utils.data import DataLoader
from typing_extensions import Protocol, override, runtime_checkable

from fkat.utils.rng import get_rng_states, set_rng_states
from fkat.utils.profiler import profile_until_exit

logger = logging.getLogger(__name__)


def _call(dataloader: Iterable[Any] | None, attr: str, *args: Any, **kwargs: Any) -> None:
    if dataloader is None:
        return
    for obj in (
        dataloader,
        getattr(dataloader, "dataset", None),
        getattr(dataloader, "sampler", None),
        getattr(dataloader, "batch_sampler", None),
    ):
        if not obj:
            continue
        if impl := getattr(obj, attr, None):
            impl(*args, **kwargs)


def worker_init_fn(
    profiler: Profiler,
    stage: str,
    init_fn: Callable[[int], Any] | None,
    worker_id: int,
) -> None:
    action = f"DataWorker[{stage}][{worker_id}]"
    profile_until_exit(profiler, action=action, filename_suffix=f"_{stage}_{worker_id}")

    if init_fn is not None:
        init_fn(worker_id)

    # TODO: Give each DataLoader worker a consistent seed based on worker_id
    # Reference: https://github.com/pytorch/pytorch/issues/5059#issuecomment-404232359


def instrument(
    cfg: dict[str, Any],
    profiler: Profiler | None,
    stage: str,
) -> dict[str, Any]:
    if not profiler:
        return cfg
    cfg["worker_init_fn"] = partial(worker_init_fn, profiler, stage, cfg.get("worker_init_fn"))
    return cfg


@runtime_checkable
class PersistStates(Protocol):
    def state_dict(self) -> dict[str, Any]: ...


@runtime_checkable
class RestoreStates(Protocol):
    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...


class DataModule(L.LightningDataModule, CheckpointHooks):
    """A :class:`LightningDataModule` that manages multiple :class:`DataLoader`\\s for different stages.

    Args:
        dataloaders (dict[str, dict[str, Any] | Callable[[], Iterable[Any]]]): Dataloaders for different stages.
        profiler (Profiler | None): Profiler instance for worker initialization.
    """

    SUPPORTED_STAGES = ("train", "test", "val", "predict")

    def __init__(
        self,
        dataloaders: dict[str, dict[str, Any] | Callable[[], Iterable[Any]]],
        profiler: Profiler | None = None,
    ) -> None:
        super().__init__()
        self.profiler = profiler
        self.dataloader_factory: dict[str, Callable[[], Iterable[Any]]] = {}
        for stage, cfg in dataloaders.items():
            if stage not in DataModule.SUPPORTED_STAGES:
                raise ValueError(f"Unsupported stage {stage}, use one of {DataModule.SUPPORTED_STAGES}")
            dataloader_factory: Callable[[], Iterable[Any]]
            if isinstance(cfg, dict):
                cfg = instrument(cfg, profiler, stage)  # type: ignore[arg-type]
                dataloader_factory = partial(DataLoader, **cfg)
            else:
                dataloader_factory = cfg
            self.dataloader_factory[stage] = dataloader_factory
        self.dataloaders: dict[str, Iterable[Any] | None] = {}

    def _new_dataloader(self, stage: str) -> Iterable[Any] | None:
        dataloader_factory = self.dataloader_factory.get(stage, lambda: None)
        self.dataloaders[stage] = (dataloader := dataloader_factory())
        return dataloader

    @override
    def prepare_data(self) -> None:
        for dataloader in self.dataloaders.values():
            _call(dataloader, "prepare_data")

    def _dataloader(self, stage: str) -> Iterable[Any] | None:
        stage = "train" if stage == "fit" else "val" if stage == "validation" else stage
        return self.dataloaders.get(stage)

    @override
    def setup(self, stage: str) -> None:
        device = self.trainer and self.trainer.strategy and self.trainer.strategy.root_device
        _call(self._dataloader(stage), "set_device", device)
        _call(self._dataloader(stage), "setup", stage)

    # will be used once https://github.com/Lightning-AI/pytorch-lightning/pull/19601 is in effect
    def on_exception(self, exception: BaseException) -> None:
        for dataloader in self.dataloaders.values():
            _call(dataloader, "on_exception", exception)

    @override
    def teardown(self, stage: str | None) -> None:
        # terminate everything regardless of which stage triggered this
        for dataloader in self.dataloaders.values():
            _call(dataloader, "teardown", stage)

    @override
    def train_dataloader(self) -> Iterable[Any] | None:
        return self._new_dataloader("train")

    @override
    def val_dataloader(self) -> Iterable[Any] | None:
        return self._new_dataloader("val")

    @override
    def predict_dataloader(self) -> Iterable[Any] | None:
        return self._new_dataloader("predict")

    @override
    def test_dataloader(self) -> Iterable[Any] | None:
        return self._new_dataloader("test")

    @override
    def state_dict(self) -> dict[str, Any]:
        """Called when saving a checkpoint; generates and saves the datamodule state for a ShardedDataLoader.

        This method iterates over each stage's dataloader to retrieve its state using the `state_dict()` method,
        and saves it along with the RNG states. If a dataloader does not implement the `PersistStates` protocol,
        it sets its `state_dict` attribute to `vanilla_dataloader_state_dict` to allow saving its state.

        Returns:
            dict[str, Any]: A dictionary containing the dataloader states and RNG states.
        """
        dataloader_states = {}
        for stage, dataloader in self.dataloaders.items():
            dataloader_states[stage] = get_rng_states()
            if isinstance(dataloader, PersistStates):
                try:
                    result_dict = dataloader.state_dict()
                    dataloader_states[stage].update(result_dict)
                except Exception as e:
                    logger.warning(f"{dataloader} states can't be persisted yet: {e}.")
        return dataloader_states

    @override
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Called when loading a checkpoint; reloads the DataModule's state for a ShardedDataLoader.

        This method iterates over each stage's dataloader, loads its state from the provided `state_dict`,
        and sets the RNG states. If a dataloader does not implement the `RestoreStates` protocol,
        it sets its `load_state_dict` attribute to `vanilla_dataloader_load_state_dict` to allow loading its state.

        Args:
            state_dict (Dict[str, Any]): A dictionary containing the dataloader states and RNG states.
        """
        for stage, dataloader in self.dataloaders.items():
            set_rng_states(state_dict[stage])
            if isinstance(dataloader, RestoreStates):
                try:
                    dataloader.load_state_dict(state_dict[stage])
                except Exception as e:
                    logger.warning(f"{dataloader} states can't be restored yet: {e}.")

    @override
    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        for dataloader in self.dataloaders.values():
            _call(dataloader, "on_save_checkpoint", checkpoint)

    @override
    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        for dataloader in self.dataloaders.values():
            _call(dataloader, "on_load_checkpoint", checkpoint)
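To make the factory behavior concrete, a small hedged sketch: per the constructor above, dict-valued entries are expanded into `torch.utils.data.DataLoader(**cfg)` via `functools.partial`, while zero-argument callables are stored and invoked as-is. The toy tensors and batch sizes below are illustrative only.

# Illustrative sketch only; the datasets and batch sizes are made up.
import torch
from torch.utils.data import DataLoader, TensorDataset

from fkat.data import DataModule

train_ds = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
val_ds = TensorDataset(torch.randn(20, 4), torch.randint(0, 2, (20,)))

dm = DataModule(
    dataloaders={
        # a dict is expanded into DataLoader(**cfg) via functools.partial
        "train": {"dataset": train_ds, "batch_size": 16, "shuffle": True},
        # a zero-argument callable is used directly as the dataloader factory
        "val": lambda: DataLoader(val_ds, batch_size=16),
    }
)
train_loader = dm.train_dataloader()  # builds and caches the "train" dataloader
val_loader = dm.val_dataloader()      # builds and caches the "val" dataloader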
fkat/data/datasets/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .sized import SizedDataset  # isort: skip
from .dict import DictDataset
from .map import MapDataset, IterableMapDataset
from .json import JsonDataset, IterableJsonDataset
from .parquet import ParquetDataset, IterableParquetDataset


__all__ = [
    "SizedDataset",
    "MapDataset",
    "IterableMapDataset",
    "DictDataset",
    "JsonDataset",
    "IterableJsonDataset",
    "ParquetDataset",
    "IterableParquetDataset",
]
fkat/data/datasets/dict.py
ADDED
@@ -0,0 +1,78 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any

from typing_extensions import override

from fkat.data.datasets import SizedDataset
from fkat.utils.config import to_primitive_container


class DictDataset(SizedDataset[tuple[str, Any], dict[str, Any]]):
    """:class:`Dataset` that can get samples from one of several :class:`Dataset`\\s using a mapping."""

    def __init__(
        self,
        datasets: dict[str, SizedDataset[Any, dict[str, Any]]],
        key: str = "dataset",
    ) -> None:
        """Create a :class:`Dataset` that can get samples from one of several :class:`Dataset`\\s using a mapping.

        Args:
            datasets (Dict[str, SizedDataset[Any, Dict[str, Any]]]): A mapping from labels to :class:`Dataset`\\s.
            key (str): The name of the field that reflects which :class:`Dataset` a sample was provided from.
                Defaults to "dataset".

        Returns:
            None
        """
        self.datasets = to_primitive_container(datasets)
        self.len = sum(len(dataset) for dataset in datasets.values())
        self.key = key

    @override
    def __len__(self) -> int:
        """Get :class:`Dataset` size.

        Returns:
            int: :class:`Dataset` size.
        """
        return self.len

    def _wrap(self, name: str, item: dict[str, Any]) -> dict[str, Any]:
        if not isinstance(item, dict) or self.key in item:
            raise RuntimeError(f"Datasets must return a dict without {self.key} key")
        item[self.key] = name
        return item

    def __getitems__(self, name_and_idxs: tuple[str, list[Any]]) -> list[dict[str, Any]]:
        """Get a batch of samples from the target :class:`Dataset` at the specified indices.

        Args:
            name_and_idxs (Tuple[str, List[Any]]): Samples' :class:`Dataset` and indices.

        Returns:
            List[Dict[str, Any]]: A batch of samples.
        """
        name, idxs = name_and_idxs
        if getitems := getattr(self.datasets[name], "__getitems__", None):
            batch = getitems(idxs)
        else:
            batch = [self.datasets[name][idx] for idx in idxs]
        for b in batch:
            self._wrap(name, b)
        return batch

    @override
    def __getitem__(self, idx: tuple[str, Any]) -> dict[str, Any]:
        """Get a sample from the target :class:`Dataset` at the specified index.

        Args:
            idx (Tuple[str, Any]): Sample :class:`Dataset` and index.

        Returns:
            Dict[str, Any]: A sample.
        """
        name, idx_ = idx
        sample = self.datasets[name][idx_]
        return self._wrap(name, sample)
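A small sketch of how `DictDataset` routes `(name, index)` keys to the wrapped datasets. The toy dataset class below is hypothetical and relies only on the `__len__`/`__getitem__` access the code above actually performs; it also assumes `to_primitive_container` passes a plain Python dict through unchanged.

# Illustrative sketch; ToyDataset is hypothetical and only provides __len__/__getitem__,
# which is all DictDataset uses at runtime. Assumes to_primitive_container passes a
# plain dict through unchanged.
from typing import Any

from fkat.data.datasets import DictDataset


class ToyDataset:
    def __init__(self, rows: list[dict[str, Any]]) -> None:
        self.rows = rows

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        return dict(self.rows[idx])


ds = DictDataset(
    datasets={  # type: ignore[arg-type]  # ToyDataset is not a SizedDataset subclass
        "news": ToyDataset([{"text": "a"}, {"text": "b"}]),
        "chat": ToyDataset([{"text": "c"}]),
    }
)
print(len(ds))                            # 3, the sum of the wrapped dataset sizes
print(ds[("chat", 0)])                    # {'text': 'c', 'dataset': 'chat'}
print(ds.__getitems__(("news", [0, 1])))  # a batch from a single wrapped dataset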
fkat/data/datasets/json.py
ADDED
@@ -0,0 +1,176 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import awswrangler as s3wr
from typing import Any
from collections.abc import Iterator

import pyarrow as pa
import pyarrow.json as pj
from pyarrow.fs import FileSystem, S3FileSystem  # type: ignore[possibly-unbound-import]
import pandas as pd
import numpy as np
from fkat.data.datasets import SizedDataset
from fkat.utils.pyarrow import iter_rows as pa_iter_rows
from fkat.utils.pandas import iter_rows as pd_iter_rows
from fkat.utils.boto3 import session
from torch.utils.data import IterableDataset


class IterableJsonDataset(IterableDataset[dict[str, Any]]):
    """
    An :class:`IterableDataset` backed by JSON data.

    Args:
        uri (str | list[str]): URI of JSON data.
        read_options: pyarrow.json.ReadOptions, optional
            Options for the JSON reader (see ReadOptions constructor for defaults).
        parse_options: pyarrow.json.ParseOptions, optional
            Options for the JSON parser
            (see ParseOptions constructor for defaults).
        memory_pool: MemoryPool, optional
            Pool to allocate Table memory from.
        chunk_size (int): An iterable of DataFrames is returned, each with at most this many rows.
        replace_nan (bool): Whether to replace np.nan with None.
            Defaults to ``True``
        s3wr_args (dict): config for s3wr.s3.read_json;
            refer to https://aws-sdk-pandas.readthedocs.io/en/3.5.1/stubs/awswrangler.s3.read_json.html
    """

    def __init__(
        self,
        uri: str | list[str],
        read_options: pa.json.ReadOptions | None = None,
        parse_options: pa.json.ParseOptions | None = None,
        memory_pool: pa.MemoryPool | None = None,
        chunk_size: int = 10000,
        replace_nan: bool = True,
        **s3wr_args: Any,
    ) -> None:
        fs: FileSystem
        path: str
        if isinstance(uri, str):
            fs, path = FileSystem.from_uri(uri)
        else:
            fs, path = FileSystem.from_uri(uri[0])
        self.s3_file = isinstance(fs, S3FileSystem)
        if self.s3_file:
            self.uri = uri
            self.chunk_size = chunk_size
            self.replace_nan = replace_nan
            self.s3wr_args = s3wr_args
        else:
            assert isinstance(uri, str), "IterableJsonDataset can only accept uri as str"
            with fs.open_input_file(path) as f:
                self.tbl = pj.read_json(
                    f, read_options=read_options, parse_options=parse_options, memory_pool=memory_pool
                )
            self.chunk_size = chunk_size

    def __iter__(self) -> Iterator[dict[str, Any]]:
        """Creates the dataset iterator.

        Returns:
            Iterator[dict[str, Any]]: dataset iterator
        """
        if self.s3_file:
            return pd_iter_rows(
                s3wr.s3.read_json(
                    self.uri,
                    lines=True,
                    chunksize=self.chunk_size,
                    boto3_session=session(clients=["s3"]),
                    path_suffix="json",
                    **self.s3wr_args,
                ),
                self.replace_nan,
            )
        else:
            return pa_iter_rows(self.tbl, chunk_size=self.chunk_size)


class JsonDataset(SizedDataset[int, dict[str, Any]]):
    """
    Create a :class:`Dataset` from JSON data at the specified URI.

    Args:
        uri (str | list[str]): URI of JSON data.
        read_options (pa.json.ReadOptions | None): JSON read options.
        parse_options (pa.json.ParseOptions | None): JSON parse options.
        memory_pool (pa.MemoryPool | None): JSON processing memory pool configuration.
        replace_nan (bool): Whether to replace np.nan with None.
            Defaults to ``True``
        s3wr_args (Any): config for s3wr.s3.read_json;
            refer to https://aws-sdk-pandas.readthedocs.io/en/3.5.1/stubs/awswrangler.s3.read_json.html
    """

    def __init__(
        self,
        uri: str | list[str],
        read_options: pa.json.ReadOptions | None = None,
        parse_options: pa.json.ParseOptions | None = None,
        memory_pool: pa.MemoryPool | None = None,
        replace_nan: bool = True,
        **s3wr_args: Any,
    ) -> None:
        fs: FileSystem
        path: str
        if isinstance(uri, str):
            fs, path = FileSystem.from_uri(uri)
        else:
            fs, path = FileSystem.from_uri(uri[0])
        if isinstance(fs, S3FileSystem):
            self.df = s3wr.s3.read_json(
                uri, lines=True, boto3_session=session(clients=["s3"]), path_suffix="json", **s3wr_args
            )
            if replace_nan:
                self.df = self.df.replace({np.nan: None})
        else:
            path_list: list[str] = []
            if isinstance(uri, str):
                path_list.append(path)
            else:
                for each in uri:
                    _, path = FileSystem.from_uri(each)
                    path_list.append(path)
            df = []
            for each in path_list:
                with fs.open_input_file(each) as f:
                    tbl = pj.read_json(
                        f, read_options=read_options, parse_options=parse_options, memory_pool=memory_pool
                    )
                    df.append(tbl.to_pandas())
            self.df = pd.concat(df)

    def __len__(self) -> int:
        """Get :class:`Dataset` size.

        Returns:
            int: :class:`Dataset` size.
        """
        return len(self.df)

    def __getitems__(self, idxs: list[int]) -> list[dict[str, Any]]:
        """Get a batch of samples at the specified indices.

        Args:
            idxs (list[int]): Samples' indices.

        Returns:
            list[dict[str, Any]]: A batch of samples.
        """
        series = self.df.iloc[idxs]
        samples = [series.iloc[i].to_dict() for i in range(len(idxs))]
        return samples

    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Get a sample at the specified index.

        Args:
            idx (int): Sample index.

        Returns:
            dict[str, Any]: A sample.
        """
        series = self.df.iloc[idx]
        sample = series.to_dict()
        return sample
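To close, a local-file sketch for `JsonDataset`. The temporary newline-delimited JSON file and its contents are hypothetical; the non-S3 path goes through `pyarrow.json.read_json` as in the constructor above.

# Hypothetical local-file example; JsonDataset reads newline-delimited JSON through
# pyarrow when the URI does not resolve to S3.
import json
import os
import tempfile

from fkat.data.datasets import JsonDataset

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "rows.json")
    with open(path, "w") as f:
        for row in ({"text": "a", "label": 0}, {"text": "b", "label": 1}):
            f.write(json.dumps(row) + "\n")

    ds = JsonDataset(uri=path)  # an absolute local path resolves to a LocalFileSystem
    print(len(ds))              # 2
    print(ds[0])                # first row as a plain dict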