fkat-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/__init__.py ADDED
@@ -0,0 +1,147 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import os
+import sys
+import asyncio
+from collections.abc import Callable
+
+import hydra
+import omegaconf as oc
+import lightning as L
+import torch
+import torch.multiprocessing as mp
+from torch.distributed.elastic.multiprocessing.errors import record
+
+from fkat.utils import config, pdb
+from fkat.utils.logging import rank0_logger
+from fkat.utils.config import SingletonResolver
+
+log = rank0_logger(__name__)
+
+
+def run_main(main: Callable[[], None]) -> None:
+    patch_args()
+
+    @record
+    async def async_main() -> None:
+        try:
+            main()
+        except Exception as e:
+            import traceback
+
+            traceback.print_tb(e.__traceback__)
+            raise e
+
+    asyncio.run(async_main())  # type: ignore[arg-type]
+
+
+def patch_args() -> None:
+    """
+    In case we need to pass wildcard arguments (e.g. overrides) as expected by Hydra,
+    but the runtime only allows named arguments, we pass them using a bogus "--overrides" flag.
+    This function takes care of removing this flag by the time we call Hydra.
+    """
+    overrides_pos = -1
+    for i, a in enumerate(sys.argv):
+        if a == "--overrides":
+            overrides_pos = i
+            break
+    if overrides_pos >= 0:
+        overrides = sys.argv[overrides_pos + 1] if overrides_pos + 1 < len(sys.argv) else ""
+        # skipping overrides when constructing new args, there could be more args ahead
+        sys.argv = sys.argv[:overrides_pos] + (
+            sys.argv[overrides_pos + 2 :] if overrides_pos + 2 < len(sys.argv) else []
+        )
+        if overrides:
+            # adding overrides at the end
+            sys.argv.extend(overrides.split(" "))
+
+
+def setup(
+    cfg: oc.DictConfig | None = None,
+    print_config: bool = False,
+    multiprocessing: str = "spawn",
+    seed: int | None = None,
+    post_mortem: bool = False,
+    determinism: bool = False,
+    resolvers: dict[str, "oc.Resolver"] | None = None,
+) -> SingletonResolver:
+    """Set up the training environment.
+
+    Args:
+        cfg (oc.DictConfig | None): Full configuration
+        print_config (bool): Whether to print configuration to output. Defaults to ``False``
+        multiprocessing (str): Multiprocessing mode. Defaults to ``spawn``
+        seed (int | None): Random number generator seed to start off when set
+        post_mortem (bool): Whether to start the pdb debugger when an uncaught exception is encountered. Defaults to ``False``
+        determinism (bool): Whether to enforce deterministic algorithms. Defaults to ``False``
+        resolvers (dict[str, oc.Resolver] | None): Custom resolvers to register for configuration processing
+
+    Returns:
+        :class:`SingletonResolver` object that holds initialized data, trainer, model, etc.
+    """
+    if print_config:
+        log.info(config.to_str(cfg))
+
+    mp.set_start_method(multiprocessing, force=True)
+
+    if seed:
+        L.seed_everything(seed)
+
+    if post_mortem:
+        pdb.post_mortem()
+
+    if determinism:  # Enable deterministic algorithms globally
+        assert seed is not None, "seed has to be set for deterministic runs"
+        torch.use_deterministic_algorithms(True)
+        if torch.cuda.is_available():
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+        deterministic_env_vars = {
+            "CUBLAS_WORKSPACE_CONFIG": [":16:8", ":4096:8"],
+            "NCCL_ALGO": ["^NVLS"],
+            "NVTE_ALLOW_NONDETERMINISTIC_ALGO": ["0"],
+        }
+        for var, vals in deterministic_env_vars.items():
+            if (val := os.environ.get(var, vals[-1])) not in vals:
+                raise ValueError(f"{var} has to be set to one of {vals} for deterministic runs, got: {val}")
+            os.environ[var] = val
+    for rn, fn in (resolvers or {}).items():
+        oc.OmegaConf.register_new_resolver(rn, fn, replace=True)
+
+    s = config.register_singleton_resolver()
+    return s
+
+
+def initialize(cfg: oc.DictConfig) -> SingletonResolver:
+    """Initialize data, model and trainer with the supplied configuration.
+
+    Args:
+        cfg (oc.DictConfig): Configuration supplied by the user through a YAML file.
+
+    Returns:
+        :class:`SingletonResolver` object that holds initialized data, trainer, model, etc.
+    """
+    # 0. setup the training environment
+    s = setup(cfg, **(hydra.utils.instantiate(cfg["setup"]) if "setup" in cfg else {}))
+
+    # 1. instantiate `trainer`
+    s.trainer = hydra.utils.instantiate(cfg.trainer)
+
+    # 2. instantiate optional `data`
+    s.data = hydra.utils.instantiate(cfg.get("data"))
+
+    # 3. instantiate `model` after `trainer`
+    s.model = hydra.utils.instantiate(cfg.model)
+
+    # 4. obtain optional `ckpt_path`, `return_predictions` after `model`
+    s.ckpt_path = hydra.utils.call(cfg.get("ckpt_path"))
+    s.return_predictions = hydra.utils.call(cfg.get("return_predictions"))
+
+    # 5. save and upload the config
+    config.save(cfg, s.trainer)
+
+    # 6. run tuners
+    s.tuners = hydra.utils.instantiate(cfg.get("tuners"))
+
+    return s
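The entry point above is typically driven by a Hydra-decorated script. A minimal usage sketch, assuming a config with `trainer`, `model` and optional `data`/`setup` nodes (the config path, config name, and the `fit` call below are illustrative, not taken from this package):

    # hypothetical train script, not part of the wheel
    import hydra
    import omegaconf as oc

    import fkat


    @hydra.main(version_base=None, config_path="conf", config_name="train")  # illustrative paths
    def main(cfg: oc.DictConfig) -> None:
        s = fkat.initialize(cfg)  # builds trainer, data and model from the config
        s.trainer.fit(s.model, datamodule=s.data, ckpt_path=s.ckpt_path)


    if __name__ == "__main__":
        fkat.run_main(main)  # strips the bogus "--overrides" flag, then runs main() under @record

With `patch_args`, an invocation such as `python train.py --overrides "trainer.max_epochs=2 seed=13"` is rewritten so that Hydra receives the two overrides as ordinary positional arguments.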
fkat/data/__init__.py ADDED
@@ -0,0 +1,15 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from .data_module import DataModule, PersistStates, RestoreStates
+from .shm import ShmDataLoader
+from .sharded import ShardedDataLoader
+from .dict import DictDataLoader
+
+__all__ = [
+    "DataModule",
+    "ShmDataLoader",
+    "ShardedDataLoader",
+    "DictDataLoader",
+    "PersistStates",
+    "RestoreStates",
+]
fkat/data/data_module.py ADDED
@@ -0,0 +1,198 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import logging
+from functools import partial
+from typing import Any
+from collections.abc import Callable, Iterable
+
+import lightning as L
+from lightning.pytorch.profilers import Profiler
+from lightning.pytorch.core.hooks import CheckpointHooks
+from torch.utils.data import DataLoader
+from typing_extensions import Protocol, override, runtime_checkable
+
+from fkat.utils.rng import get_rng_states, set_rng_states
+from fkat.utils.profiler import profile_until_exit
+
+logger = logging.getLogger(__name__)
+
+
+def _call(dataloader: Iterable[Any] | None, attr: str, *args: Any, **kwargs: Any) -> None:
+    if dataloader is None:
+        return
+    for obj in (
+        dataloader,
+        getattr(dataloader, "dataset", None),
+        getattr(dataloader, "sampler", None),
+        getattr(dataloader, "batch_sampler", None),
+    ):
+        if not obj:
+            continue
+        if impl := getattr(obj, attr, None):
+            impl(*args, **kwargs)
+
+
+def worker_init_fn(
+    profiler: Profiler,
+    stage: str,
+    init_fn: Callable[[int], Any] | None,
+    worker_id: int,
+) -> None:
+    action = f"DataWorker[{stage}][{worker_id}]"
+    profile_until_exit(profiler, action=action, filename_suffix=f"_{stage}_{worker_id}")
+
+    if init_fn is not None:
+        init_fn(worker_id)
+
+    # TODO: give each DataLoader worker a consistent seed based on worker_id
+    # Reference: https://github.com/pytorch/pytorch/issues/5059#issuecomment-404232359
+
+
+def instrument(
+    cfg: dict[str, Any],
+    profiler: Profiler | None,
+    stage: str,
+) -> dict[str, Any]:
+    if not profiler:
+        return cfg
+    cfg["worker_init_fn"] = partial(worker_init_fn, profiler, stage, cfg.get("worker_init_fn"))
+    return cfg
+
+
+@runtime_checkable
+class PersistStates(Protocol):
+    def state_dict(self) -> dict[str, Any]: ...
+
+
+@runtime_checkable
+class RestoreStates(Protocol):
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...
+
+
+class DataModule(L.LightningDataModule, CheckpointHooks):
+    """A :class:`LightningDataModule` that manages multiple :class:`DataLoader`\\s for different stages.
+
+    Args:
+        dataloaders (dict[str, dict[str, Any] | Callable[[], Iterable[Any]]]): Dataloaders for different stages.
+        profiler (Profiler | None): Profiler instance for worker initialization.
+    """
+
+    SUPPORTED_STAGES = ("train", "test", "val", "predict")
+
+    def __init__(
+        self,
+        dataloaders: dict[str, dict[str, Any] | Callable[[], Iterable[Any]]],
+        profiler: Profiler | None = None,
+    ) -> None:
+        super().__init__()
+        self.profiler = profiler
+        self.dataloader_factory: dict[str, Callable[[], Iterable[Any]]] = {}
+        for stage, cfg in dataloaders.items():
+            if stage not in DataModule.SUPPORTED_STAGES:
+                raise ValueError(f"Unsupported stage {stage}, use one of {DataModule.SUPPORTED_STAGES}")
+            dataloader_factory: Callable[[], Iterable[Any]]
+            if isinstance(cfg, dict):
+                cfg = instrument(cfg, profiler, stage)  # type: ignore[arg-type]
+                dataloader_factory = partial(DataLoader, **cfg)
+            else:
+                dataloader_factory = cfg
+            self.dataloader_factory[stage] = dataloader_factory
+        self.dataloaders: dict[str, Iterable[Any] | None] = {}
+
+    def _new_dataloader(self, stage: str) -> Iterable[Any] | None:
+        dataloader_factory = self.dataloader_factory.get(stage, lambda: None)
+        self.dataloaders[stage] = (dataloader := dataloader_factory())
+        return dataloader
+
+    @override
+    def prepare_data(self) -> None:
+        for dataloader in self.dataloaders.values():
+            _call(dataloader, "prepare_data")
+
+    def _dataloader(self, stage: str) -> Iterable[Any] | None:
+        stage = "train" if stage == "fit" else "val" if stage == "validation" else stage
+        return self.dataloaders.get(stage)
+
+    @override
+    def setup(self, stage: str) -> None:
+        device = self.trainer and self.trainer.strategy and self.trainer.strategy.root_device
+        _call(self._dataloader(stage), "set_device", device)
+        _call(self._dataloader(stage), "setup", stage)
+
+    # will be used once https://github.com/Lightning-AI/pytorch-lightning/pull/19601 is in effect
+    def on_exception(self, exception: BaseException) -> None:
+        for dataloader in self.dataloaders.values():
+            _call(dataloader, "on_exception", exception)
+
+    @override
+    def teardown(self, stage: str | None) -> None:
+        # this is it, terminating everything regardless of which stage we received this
+        for dataloader in self.dataloaders.values():
+            _call(dataloader, "teardown", stage)
+
+    @override
+    def train_dataloader(self) -> Iterable[Any] | None:
+        return self._new_dataloader("train")
+
+    @override
+    def val_dataloader(self) -> Iterable[Any] | None:
+        return self._new_dataloader("val")
+
+    @override
+    def predict_dataloader(self) -> Iterable[Any] | None:
+        return self._new_dataloader("predict")
+
+    @override
+    def test_dataloader(self) -> Iterable[Any] | None:
+        return self._new_dataloader("test")
+
+    @override
+    def state_dict(self) -> dict[str, Any]:
+        """Called when saving a checkpoint, implement to generate and save datamodule state for ShardedDataLoader.
+
+        This method iterates over each stage's dataloader to retrieve its state using the `state_dict()` method,
+        and saves it along with the RNG states. If a dataloader does not implement the `PersistStates` protocol,
+        it sets its `state_dict` attribute to `vanilla_dataloader_state_dict` to allow saving its state.
+
+        Returns:
+            dict[str, Any]: A dictionary containing the dataloader states and RNG states.
+        """
+        dataloader_states = {}
+        for stage, dataloader in self.dataloaders.items():
+            dataloader_states[stage] = get_rng_states()
+            if isinstance(dataloader, PersistStates):
+                try:
+                    result_dict = dataloader.state_dict()
+                    dataloader_states[stage].update(result_dict)
+                except Exception as e:
+                    logger.warning(f"{dataloader} states can't be persisted yet: {e}.")
+        return dataloader_states
+
+    @override
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        """Called when loading a checkpoint to reload the DataModule's state for a ShardedDataLoader.
+
+        This method iterates over each stage's dataloader, loads its state from the provided `state_dict`,
+        and sets the RNG states. If a dataloader does not implement the `RestoreStates` protocol,
+        it sets its `load_state_dict` attribute to `vanilla_dataloader_load_state_dict` to allow loading its state.
+
+        Args:
+            state_dict (Dict[str, Any]): A dictionary containing the dataloader states and RNG states.
+        """
+        for stage, dataloader in self.dataloaders.items():
+            set_rng_states(state_dict[stage])
+            if isinstance(dataloader, RestoreStates):
+                try:
+                    dataloader.load_state_dict(state_dict[stage])
+                except Exception as e:
+                    logger.warning(f"{dataloader} states can't be restored yet: {e}.")
+
+    @override
+    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        for dataloader in self.dataloaders.values():
+            _call(dataloader, "on_save_checkpoint", checkpoint)
+
+    @override
+    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        for dataloader in self.dataloaders.values():
+            _call(dataloader, "on_load_checkpoint", checkpoint)
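To make the dataloader configuration concrete, here is a small sketch (dataset objects, batch sizes and paths are illustrative) of the two accepted forms: a kwargs dict forwarded to `torch.utils.data.DataLoader`, and a zero-argument factory callable:

    from torch.utils.data import DataLoader

    from fkat.data import DataModule
    from fkat.data.datasets import JsonDataset

    train_ds = JsonDataset("file:///tmp/train.json")  # illustrative paths
    val_ds = JsonDataset("file:///tmp/val.json")

    dm = DataModule(
        dataloaders={
            # dict form: kwargs passed to DataLoader; worker_init_fn is wrapped when a profiler is supplied
            "train": {"dataset": train_ds, "batch_size": 32, "num_workers": 4},
            # callable form: any zero-argument factory returning an iterable
            "val": lambda: DataLoader(val_ds, batch_size=64),
        }
    )
    # trainer.fit(model, datamodule=dm)  # each *_dataloader() hook builds a fresh loader via its factory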
fkat/data/datasets/__init__.py ADDED
@@ -0,0 +1,19 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from .sized import SizedDataset  # isort: skip
+from .dict import DictDataset
+from .map import MapDataset, IterableMapDataset
+from .json import JsonDataset, IterableJsonDataset
+from .parquet import ParquetDataset, IterableParquetDataset
+
+
+__all__ = [
+    "SizedDataset",
+    "MapDataset",
+    "IterableMapDataset",
+    "DictDataset",
+    "JsonDataset",
+    "IterableJsonDataset",
+    "ParquetDataset",
+    "IterableParquetDataset",
+]
fkat/data/datasets/dict.py ADDED
@@ -0,0 +1,78 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+from typing_extensions import override
+
+from fkat.data.datasets import SizedDataset
+from fkat.utils.config import to_primitive_container
+
+
+class DictDataset(SizedDataset[tuple[str, Any], dict[str, Any]]):
+    """:class:`Dataset` that can get samples from one of several underlying :class:`Dataset` objects using a mapping."""
+
+    def __init__(
+        self,
+        datasets: dict[str, SizedDataset[Any, dict[str, Any]]],
+        key: str = "dataset",
+    ) -> None:
+        """Create a :class:`Dataset` that can get samples from one of several underlying :class:`Dataset` objects using a mapping.
+
+        Args:
+            datasets (Dict[str, SizedDataset[Any, Dict[str, Any]]]): A mapping from labels to :class:`Dataset`\\s.
+            key (str): The name of the field to reflect the :class:`Dataset` the samples were provided from.
+                Defaults to "dataset".
+
+        Returns:
+            None
+        """
+        self.datasets = to_primitive_container(datasets)
+        self.len = sum(len(dataset) for dataset in datasets.values())
+        self.key = key
+
+    @override
+    def __len__(self) -> int:
+        """Get :class:`Dataset` size.
+
+        Returns:
+            int: :class:`Dataset` size.
+        """
+        return self.len
+
+    def _wrap(self, name: str, item: dict[str, Any]) -> dict[str, Any]:
+        if not isinstance(item, dict) or self.key in item:
+            raise RuntimeError(f"Datasets must return a dict without {self.key} key")
+        item[self.key] = name
+        return item
+
+    def __getitems__(self, name_and_idxs: tuple[str, list[Any]]) -> list[dict[str, Any]]:
+        """Get a batch of samples from the target :class:`Dataset` at the specified indices.
+
+        Args:
+            name_and_idxs (Tuple[str, List[Any]]): Samples' :class:`Dataset` and indices.
+
+        Returns:
+            List[Dict[str, Any]]: A batch of samples.
+        """
+        name, idxs = name_and_idxs
+        if getitems := getattr(self.datasets[name], "__getitems__", None):
+            batch = getitems(idxs)
+        else:
+            batch = [self.datasets[name][idx] for idx in idxs]
+        for b in batch:
+            self._wrap(name, b)
+        return batch
+
+    @override
+    def __getitem__(self, idx: tuple[str, Any]) -> dict[str, Any]:
+        """Get a sample from the target :class:`Dataset` at the specified index.
+
+        Args:
+            idx (Tuple[str, Any]): Sample :class:`Dataset` and index.
+
+        Returns:
+            Dict[str, Any]: A sample.
+        """
+        name, idx_ = idx
+        sample = self.datasets[name][idx_]
+        return self._wrap(name, sample)
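A brief sketch of how `DictDataset` routes lookups (datasets and paths are illustrative); indices are `(dataset_name, index)` tuples, normally produced by a matching sampler such as the one in fkat/data/samplers/dict.py:

    from fkat.data.datasets import DictDataset, JsonDataset

    ds = DictDataset(
        datasets={
            "wiki": JsonDataset("file:///tmp/wiki.json"),    # illustrative paths
            "books": JsonDataset("file:///tmp/books.json"),
        },
        key="dataset",
    )

    sample = ds[("wiki", 3)]            # sample 3 from the "wiki" dataset
    assert sample["dataset"] == "wiki"  # each sample is tagged with the dataset it came from
    batch = ds.__getitems__(("books", [0, 1, 2]))  # batched access delegates to the wrapped dataset when supported
    print(len(ds))                      # total size is the sum of all wrapped dataset sizes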
fkat/data/datasets/json.py ADDED
@@ -0,0 +1,176 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import awswrangler as s3wr
+from typing import Any
+from collections.abc import Iterator
+
+import pyarrow as pa
+import pyarrow.json as pj
+from pyarrow.fs import FileSystem, S3FileSystem  # type: ignore[possibly-unbound-import]
+import pandas as pd
+import numpy as np
+from fkat.data.datasets import SizedDataset
+from fkat.utils.pyarrow import iter_rows as pa_iter_rows
+from fkat.utils.pandas import iter_rows as pd_iter_rows
+from fkat.utils.boto3 import session
+from torch.utils.data import IterableDataset
+
+
+class IterableJsonDataset(IterableDataset[dict[str, Any]]):
+    """
+    An :class:`IterableDataset` backed by JSON data.
+
+    Args:
+        uri (str | list[str]): URI of JSON data.
+        read_options: pyarrow.json.ReadOptions, optional
+            Options for the JSON reader (see ReadOptions constructor for defaults).
+        parse_options: pyarrow.json.ParseOptions, optional
+            Options for the JSON parser
+            (see ParseOptions constructor for defaults).
+        memory_pool: MemoryPool, optional
+            Pool to allocate Table memory from.
+        chunk_size (int): An iterable of DataFrames is returned with maximum rows equal to the received integer.
+        replace_nan (bool): Whether to replace np.nan with None.
+            Defaults to ``True``
+        s3wr_args (dict): config for s3wr.s3.read_json,
+            refer to https://aws-sdk-pandas.readthedocs.io/en/3.5.1/stubs/awswrangler.s3.read_json.html
+
+    """
+
+    def __init__(
+        self,
+        uri: str | list[str],
+        read_options: pa.json.ReadOptions | None = None,
+        parse_options: pa.json.ParseOptions | None = None,
+        memory_pool: pa.MemoryPool | None = None,
+        chunk_size: int = 10000,
+        replace_nan: bool = True,
+        **s3wr_args: Any,
+    ) -> None:
+        fs: FileSystem
+        path: str
+        if isinstance(uri, str):
+            fs, path = FileSystem.from_uri(uri)
+        else:
+            fs, path = FileSystem.from_uri(uri[0])
+        self.s3_file = isinstance(fs, S3FileSystem)
+        if self.s3_file:
+            self.uri = uri
+            self.chunk_size = chunk_size
+            self.replace_nan = replace_nan
+            self.s3wr_args = s3wr_args
+        else:
+            assert isinstance(uri, str), "IterableJsonDataset can only accept uri as str"
+            with fs.open_input_file(path) as f:
+                self.tbl = pj.read_json(
+                    f, read_options=read_options, parse_options=parse_options, memory_pool=memory_pool
+                )
+            self.chunk_size = chunk_size
+
+    def __iter__(self) -> Iterator[dict[str, Any]]:
+        """Creates dataset iterator.
+        Returns:
+            Iterator[dict[str, Any]]: dataset iterator
+        """
+        if self.s3_file:
+            return pd_iter_rows(
+                s3wr.s3.read_json(
+                    self.uri,
+                    lines=True,
+                    chunksize=self.chunk_size,
+                    boto3_session=session(clients=["s3"]),
+                    path_suffix="json",
+                    **self.s3wr_args,
+                ),
+                self.replace_nan,
+            )
+        else:
+            return pa_iter_rows(self.tbl, chunk_size=self.chunk_size)
+
+
+class JsonDataset(SizedDataset[int, dict[str, Any]]):
+    """
+    Create a :class:`Dataset` from JSON data at the specified URI.
+
+    Args:
+        uri (str | list[str]): URI of JSON data.
+        read_options (pa.json.ReadOptions | None): JSON read options.
+        parse_options (pa.json.ParseOptions | None): JSON parse options.
+        memory_pool (pa.MemoryPool | None): JSON processing memory pool configuration.
+        replace_nan (bool): Whether to replace np.nan with None.
+            Defaults to ``True``
+        s3wr_args (Any): config for s3wr.s3.read_json,
+            refer to https://aws-sdk-pandas.readthedocs.io/en/3.5.1/stubs/awswrangler.s3.read_json.html
+    """
+
+    def __init__(
+        self,
+        uri: str | list[str],
+        read_options: pa.json.ReadOptions | None = None,
+        parse_options: pa.json.ParseOptions | None = None,
+        memory_pool: pa.MemoryPool | None = None,
+        replace_nan: bool = True,
+        **s3wr_args: Any,
+    ) -> None:
+        fs: FileSystem
+        path: str
+        if isinstance(uri, str):
+            fs, path = FileSystem.from_uri(uri)
+        else:
+            fs, path = FileSystem.from_uri(uri[0])
+        if isinstance(fs, S3FileSystem):
+            self.df = s3wr.s3.read_json(
+                uri, lines=True, boto3_session=session(clients=["s3"]), path_suffix="json", **s3wr_args
+            )
+            if replace_nan:
+                self.df = self.df.replace({np.nan: None})
+        else:
+            path_list: list[str] = []
+            if isinstance(uri, str):
+                path_list.append(path)
+            else:
+                for each in uri:
+                    _, path = FileSystem.from_uri(each)
+                    path_list.append(path)
+            df = []
+            for each in path_list:
+                with fs.open_input_file(each) as f:
+                    tbl = pj.read_json(
+                        f, read_options=read_options, parse_options=parse_options, memory_pool=memory_pool
+                    )
+                df.append(tbl.to_pandas())
+            self.df = pd.concat(df)
+
+    def __len__(self) -> int:
+        """Get :class:`Dataset` size.
+
+        Returns:
+            int: :class:`Dataset` size.
+        """
+        return len(self.df)
+
+    def __getitems__(self, idxs: list[int]) -> list[dict[str, Any]]:
+        """Get a batch of samples at the specified indices.
+
+        Args:
+            idxs (list[int]): Samples' indices.
+
+        Returns:
+            list[dict[str, Any]]: A batch of samples.
+        """
+        series = self.df.iloc[idxs]
+        samples = [series.iloc[i].to_dict() for i in range(len(idxs))]
+        return samples
+
+    def __getitem__(self, idx: int) -> dict[str, Any]:
+        """Get a sample at the specified index.
+
+        Args:
+            idx (int): Sample index.
+
+        Returns:
+            dict[str, Any]: A sample.
+        """
+        series = self.df.iloc[idx]
+        sample = series.to_dict()
+        return sample
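A short usage sketch of the two JSON dataset flavors (URIs are illustrative): S3 URIs are read through awswrangler as JSON lines, while local paths go through pyarrow.json:

    from fkat.data.datasets import JsonDataset, IterableJsonDataset

    # map-style: the data is materialized into a pandas DataFrame up front
    local_ds = JsonDataset("file:///tmp/train.json")
    row = local_ds[0]  # dict[str, Any]

    # iterable: rows are streamed in chunks of `chunk_size`, useful when the data does not fit in memory
    s3_ds = IterableJsonDataset("s3://my-bucket/data/part-0.json", chunk_size=5000)  # illustrative bucket
    for row in s3_ds:
        ...  # each row is a dict[str, Any]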