fkat-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0

fkat/pytorch/schedule/mlflow.py
ADDED
@@ -0,0 +1,143 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import lightning as L
from typing_extensions import override

from fkat.pytorch.schedule import Schedule
from fkat.pytorch.callbacks.loggers import CallbackLogger

logger: logging.Logger = logging.getLogger(__name__)


class HasTag(Schedule):
    """
    A schedule that activates only when a specific MLflow tag is present AND the trigger schedule is satisfied.

    This schedule combines another schedule (the trigger schedule) with MLflow tag validation.
    It allows callbacks to be dynamically enabled or disabled through experiment configuration rather than
    code changes, which is particularly useful for performance-intensive callbacks like FLOP measurement
    or detailed logging that should only run conditionally.

    The schedule checks two conditions:

    1. If the trigger schedule is satisfied
    2. If the specified MLflow tag exists in the current MLflow run

    Both conditions must be true for the schedule to activate. If the trigger schedule doesn't activate or
    the trainer is not provided, the schedule will never activate.

    Note:
        - The trainer can be optionally provided to the ``check`` method for MLflow tag validation.
          If trainer is None, the schedule will never activate.
        - Tag checking occurs only when the trigger schedule condition is already satisfied,
          minimizing MLflow API calls.
        - If an exception occurs during tag checking, it will be logged and the schedule will not activate.

    Example:
        Python code example::

            # Create a schedule that checks every 5 batches if the 'enable_flops' tag exists
            from fkat.pytorch.schedule import Every

            trigger = Every(n_batches=5)
            flops_schedule = HasTag(tag="enable_flops", schedule=trigger)
            flops_callback = Flops(schedule=flops_schedule)
            trainer = L.Trainer(callbacks=[flops_callback])

        Hydra configuration example:

        .. code-block:: yaml

            # In your config.yaml file
            callbacks:
              - _target_: fkat.pytorch.callbacks.profiling.Flops
                schedule:
                  _target_: fkat.pytorch.schedule.mlflow.HasTag
                  tag: ENABLE_FLOPS
                  schedule:
                    _target_: fkat.pytorch.schedule.Every
                    n_steps: 20

            # Another example using Fixed schedule
            callbacks:
              - _target_: fkat.pytorch.callbacks.heartbeat.Heartbeat
                schedule:
                  _target_: fkat.pytorch.schedule.mlflow.HasTag
                  tag: ENABLE_HEARTBEAT
                  schedule:
                    _target_: fkat.pytorch.schedule.Fixed
                    warmup_steps: 100
                    active_steps: 1000

            # Example with Elapsed time-based schedule
            callbacks:
              - _target_: fkat.pytorch.callbacks.custom_logging.DetailedMetrics
                schedule:
                  _target_: fkat.pytorch.schedule.mlflow.HasTag
                  tag: DETAILED_LOGGING
                  schedule:
                    _target_: fkat.pytorch.schedule.Elapsed
                    interval: ${timedelta:minutes=15}
    """

    def __init__(self, tag: str, schedule: Schedule) -> None:
        """
        Initialize a new MLflow HasTag schedule.

        Args:
            tag (str): The name of the tag that must be present in the MLflow run.
            schedule (Schedule): The schedule that determines when to check for the tag.
                This can be any implementation of the Schedule protocol (e.g., Every, Fixed, Elapsed).
        """
        self._tag: str = tag
        self._schedule: Schedule = schedule
        self._cb_logger: CallbackLogger | None = None

    @override
    def check(
        self,
        *,
        stage: str | None = None,
        batch_idx: int | None = None,
        step: int | None = None,
        trainer: L.Trainer | None = None,
    ) -> bool:
        """
        Check if the schedule should activate based on the trigger schedule and MLflow tag presence.

        This method first checks if the trigger schedule is satisfied.
        If this condition is met, it then checks if the specified MLflow tag exists in the current run.
        Both conditions must be true for the method to return True.

        Args:
            stage (str, optional): Current training stage (e.g., "train", "validate", "test").
                Passed to the trigger schedule.
            batch_idx (int, optional): Current batch index within the epoch.
                Passed to the trigger schedule.
            step (int, optional): Current global step (cumulative across epochs).
                Passed to the trigger schedule.
            trainer (L.Trainer, optional): The Lightning Trainer instance.
                Required for MLflow tag validation.

        Returns:
            bool: True if both the trigger schedule is satisfied AND the specified tag is present,
            False otherwise.

        Note:
            - The trainer must be provided for MLflow tag validation.
            - Tag checking occurs only when the trigger schedule is already satisfied.
            - If an exception occurs during tag checking, it will be logged as a warning
              and the method will return False.
        """
        triggered = self._schedule.check(stage=stage, batch_idx=batch_idx, step=step, trainer=trainer)
        if not triggered or trainer is None:
            return False
        try:
            if self._cb_logger is None:
                self._cb_logger = CallbackLogger(trainer)

            tags = self._cb_logger.tags()
            return self._tag in tags
        except Exception as e:
            logger.warning(f"Error when checking if tag {self._tag} exists: {e}")
            return False
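
A short usage sketch (editor's addition, not package code) tying the pieces above together. ``Every(n_batches=5)`` is taken from the docstring example; the assertion follows directly from ``check()`` returning False whenever no trainer is supplied:

    from fkat.pytorch.schedule import Every
    from fkat.pytorch.schedule.mlflow import HasTag

    schedule = HasTag(tag="enable_flops", schedule=Every(n_batches=5))

    # Without a trainer there is no MLflow run to inspect, so the schedule
    # never activates regardless of the trigger.
    assert schedule.check(stage="train", batch_idx=5, trainer=None) is False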

fkat/pytorch/utilities.py
ADDED
@@ -0,0 +1,49 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
from functools import wraps
from typing import TypeVar, ParamSpec
from collections.abc import Callable
from typing_extensions import overload


def get_rank() -> int | None:
    return int(os.getenv("RANK", "0"))


def get_local_rank() -> int | None:
    return int(os.getenv("LOCAL_RANK", "0"))


T = TypeVar("T")
P = ParamSpec("P")


@overload
def local_rank_zero_only(fn: Callable[P, T]) -> Callable[P, T | None]: ...


@overload
def local_rank_zero_only(fn: Callable[P, T], default: T) -> Callable[P, T]: ...


def local_rank_zero_only(fn: Callable[P, T], default: T | None = None) -> Callable[P, T | None]:
    """Wrap a function so that the inner function is called only on local rank zero.

    Can be used as a decorator to enable a function/method being called only on local rank 0.

    """

    @wraps(fn)
    def wrapped_fn(*args: P.args, **kwargs: P.kwargs) -> T | None:
        local_rank = getattr(local_rank_zero_only, "local_rank", None)
        if local_rank is None:
            raise RuntimeError("The `local_rank_zero_only.local_rank` needs to be set before use")
        if local_rank == 0:
            return fn(*args, **kwargs)
        return default

    return wrapped_fn


local_rank_zero_only.local_rank = getattr(local_rank_zero_only, "local_rank", get_local_rank() or 0)  # type: ignore[attr-defined]
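
A minimal usage sketch (not part of the package) for the decorator above; it assumes the process environment provides ``LOCAL_RANK``, which seeds ``local_rank_zero_only.local_rank`` at import time:

    from fkat.pytorch.utilities import local_rank_zero_only

    @local_rank_zero_only
    def report() -> str:
        # Runs only where LOCAL_RANK == 0; other ranks get the default (None).
        return "only local rank 0 sees this"

    print(report())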
fkat/test.py
ADDED
@@ -0,0 +1,31 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#!/usr/bin/env python

"""
The ``fkat.test`` entrypoint processes the provided config,
instantiates the ``trainer``, ``model`` and ``data`` sections and calls ``trainer.test()``.
"""

import hydra
import lightning as L
from omegaconf import DictConfig

from fkat import initialize, run_main


@hydra.main(version_base="1.3")
def main(cfg: DictConfig) -> None:
    s = initialize(cfg)
    kwargs = {
        "ckpt_path": s.ckpt_path,
    }
    if isinstance(s.data, L.LightningDataModule):
        kwargs["datamodule"] = s.data
    else:
        kwargs["test_dataloaders"] = s.data.test_dataloader() if s.data else None
    s.trainer.test(s.model, **kwargs)


if __name__ == "__main__":
    run_main(main)

fkat/train.py
ADDED
@@ -0,0 +1,32 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#!/usr/bin/env python

"""
The ``fkat.train`` entrypoint processes the provided config,
instantiates the ``trainer``, ``model`` and ``data`` sections and calls ``trainer.fit()``.
"""

import hydra
import lightning as L
from omegaconf import DictConfig

from fkat import initialize, run_main


@hydra.main(version_base="1.3")
def main(cfg: DictConfig) -> None:
    s = initialize(cfg)
    kwargs = {
        "ckpt_path": s.ckpt_path,
    }
    if isinstance(s.data, L.LightningDataModule):
        kwargs["datamodule"] = s.data
    else:
        kwargs["train_dataloaders"] = s.data.train_dataloader() if s.data else None
        kwargs["val_dataloaders"] = s.data.val_dataloader() if s.data else None
    s.trainer.fit(s.model, **kwargs)


if __name__ == "__main__":
    run_main(main)
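
Both entrypoints share the same dispatch rule for the ``data`` section; here is a self-contained sketch of that rule (the helper name is hypothetical, not package code):

    import lightning as L

    def build_fit_kwargs(data: object, ckpt_path: str | None = None) -> dict:
        """Mirror of the kwargs assembly in fkat.train.main."""
        kwargs: dict = {"ckpt_path": ckpt_path}
        if isinstance(data, L.LightningDataModule):
            # A LightningDataModule is handed to the trainer whole.
            kwargs["datamodule"] = data
        else:
            # Anything else is unpacked into explicit dataloader kwargs.
            kwargs["train_dataloaders"] = data.train_dataloader() if data else None
            kwargs["val_dataloaders"] = data.val_dataloader() if data else None
        return kwargs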
fkat/utils/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime, timezone
from typing import TypeVar

T = TypeVar("T")


def assert_not_none(obj: T | None, name: str = "obj") -> T:
    assert obj is not None, f"{name} cannot be None"
    return obj


def safe_timestamp(dt: datetime | None = None) -> str:
    """
    Generate a filesystem-safe timestamp string.

    Format: YYYY-MM-DD_HH-MM-SS-mmm (e.g., 2026-01-14_16-09-15-123)

    Args:
        dt: datetime object to format. If None, uses current UTC time.

    Returns:
        Formatted timestamp string safe for use in filenames
    """
    if dt is None:
        dt = datetime.now(timezone.utc)
    return dt.strftime("%Y-%m-%d_%H-%M-%S-") + f"{dt.microsecond // 1000:03d}"
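
A quick check of the format described in the docstring (editor's sketch; the datetime is arbitrary):

    from datetime import datetime, timezone
    from fkat.utils import safe_timestamp

    dt = datetime(2026, 1, 14, 16, 9, 15, 123000, tzinfo=timezone.utc)
    print(safe_timestamp(dt))  # -> 2026-01-14_16-09-15-123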
fkat/utils/aws/imds.py
ADDED
@@ -0,0 +1,137 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""
EC2 Instance Metadata Service related methods
"""

import logging
import socket
from dataclasses import dataclass
from functools import lru_cache

import requests
from requests import HTTPError

from fkat.utils.logging import rank0_logger

IMDS_URL = "http://169.254.169.254/latest"
IMDS_METADATA_URL = f"{IMDS_URL}/meta-data"
IMDS_V2_TOKEN_URL = f"{IMDS_URL}/api/token"
NULL = "_NULL_"  # sentinel value used to mark null (not-available) values


log: logging.Logger = rank0_logger(__name__)

Token = str


@dataclass
class InstanceMetadata:
    """
    Struct representing the instance metadata as fetched from IMDS on the current host.
    Use :py:func:`fkat.utils.aws.imds.instance_metadata` to get a filled-out
    instance of this object.
    """

    instance_id: str
    instance_type: str
    hostname: str
    public_hostname: str
    local_hostname: str
    local_ipv4: str
    availability_zone: str
    region: str
    ami_id: str


@lru_cache
def fetch(metadata: str = "", token: Token | None = None) -> str | None:
    """
    Fetches the specified ``metadata`` from EC2's Instance Metadata Service (IMDS) running
    on the current host, by sending an HTTP GET request to ``http://169.254.169.254/latest/meta-data/<metadata>``.

    To get a list of all valid values of ``metadata``, run this method with no arguments, then split
    the return value by new-line.

    Arguments:
        metadata: Name of the instance metadata to query (e.g. ``instance-type``)
        token: IMDS token

    Returns:
        the specified ``metadata`` or ``None`` if IMDS cannot be reached
    """

    try:
        response = requests.get(f"{IMDS_METADATA_URL}/{metadata}", headers={"X-aws-ec2-metadata-token": token or ""})
    except Exception as e:
        log.warning("Error querying IMDSv2, instance metadata won't be available", exc_info=e)
        return None

    if response.ok:
        return response.text
    else:  # response NOT ok
        try:
            response.raise_for_status()
        except HTTPError:
            return None

    raise AssertionError("Unreachable code!")


def token(timeout: int = 60) -> Token | None:
    """
    Fetches an IMDS ``token`` from EC2's Instance Metadata Service (IMDSv2) running
    on the current host, by sending an HTTP PUT request to ``http://169.254.169.254/latest/api/token``.

    Arguments:
        timeout: request timeout

    Returns:
        the ``token`` or ``""`` if IMDSv2 cannot be reached
    """

    try:
        response = requests.put(
            IMDS_V2_TOKEN_URL, headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"}, timeout=timeout
        )
    except Exception as e:
        log.warning("Error querying IMDSv2, instance token won't be available", exc_info=e)
        return ""

    if response.ok:
        return response.text
    else:  # response NOT ok
        error_code_msg = f"IMDS Token response is not ok with status code {response.status_code}"
        log.warning("Error querying IMDSv2, instance token won't be available", exc_info=Exception(error_code_msg))
        return ""

    raise AssertionError("Unreachable code!")


@lru_cache
def instance_metadata() -> InstanceMetadata:
    """
    Fetches IMDS instance metadata for the current host from EC2's Instance Metadata Service (IMDS),
    which typically runs on localhost at ``http://169.254.169.254``.
    If IMDS cannot be reached for any reason, returns an instance of :py:class:`InstanceMetadata`
    where the fields are filled with fallback values (``_NULL_`` or local hostname defaults).

    .. note::
        This method is memoized (the value is cached), hence only the first call
        will actually hit IMDS, and subsequent calls will return the memoized
        value. Therefore, it is ok to call this function multiple times.

    """
    tkn = token()
    return InstanceMetadata(
        instance_id=fetch("instance-id", tkn) or "localhost",
        instance_type=fetch("instance-type", tkn) or NULL,
        availability_zone=fetch("placement/availability-zone", tkn) or NULL,
        region=fetch("placement/region", tkn) or NULL,
        hostname=fetch("hostname", tkn) or socket.gethostname(),
        local_ipv4=fetch("local-ipv4", tkn) or NULL,
        public_hostname=fetch("public-hostname", tkn) or NULL,
        local_hostname=fetch("local-hostname", tkn) or NULL,
        ami_id=fetch("ami-id", tkn) or NULL,
    )
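
A brief usage sketch, assuming nothing beyond the functions above; off EC2 the fields fall back as described, and repeated calls are served from the ``lru_cache``:

    from fkat.utils.aws.imds import instance_metadata

    md = instance_metadata()  # hits IMDS at most once per process
    print(md.instance_type, md.region)  # "_NULL_ _NULL_" when IMDS is unreachable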
fkat/utils/boto3.py
ADDED
@@ -0,0 +1,24 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Literal

import boto3
from botocore.config import Config


def session(
    max_attempts: int = 6,
    mode: Literal["legacy", "standard", "adaptive"] = "standard",
    clients: list[str] | None = None,
) -> boto3.Session:
    config = Config(
        retries={
            "max_attempts": max_attempts,
            "mode": mode,
        }
    )
    session = boto3.Session()
    if clients:
        for client in clients:
            session.client(client, config=config)  # type: ignore
    return session
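
A minimal usage sketch for the retry-configured session factory above; the client names are arbitrary examples, and the ``clients`` argument eagerly constructs those clients with the retry ``Config``:

    from fkat.utils.boto3 import session

    # Adaptive retry mode with up to 10 attempts for the eagerly-created clients.
    sess = session(max_attempts=10, mode="adaptive", clients=["s3", "dynamodb"])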
fkat/utils/config.py
ADDED
@@ -0,0 +1,194 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Any, Protocol
from tempfile import TemporaryDirectory

from lightning import Trainer as LightningTrainer
from lightning.pytorch.utilities import rank_zero_only
from omegaconf import OmegaConf

from fkat.utils.mlflow import mlflow_logger


class Trainer(Protocol):
    """
    Protocol defining the interface for training models.

    This protocol establishes the required methods that any trainer implementation
    must provide. It serves as a structural subtyping interface for objects that
    can train, evaluate, and make predictions with machine learning models.

    Implementations of this protocol should handle the complete training lifecycle,
    including model fitting, prediction, testing, and validation.

    Note:
        As a Protocol class, this is not meant to be instantiated directly but
        rather used for type checking and interface definition.
    """

    def fit(self, *args: Any, **kwargs: Any) -> Any:
        """
        Train a model on the provided data.

        This method handles the model training process, including data loading,
        optimization, and potentially checkpointing.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            Any: Training results, which may include training history,
            trained model, or other relevant information.
        """
        ...

    def predict(self, *args: Any, **kwargs: Any) -> Any:
        """
        Generate predictions using a trained model.

        This method applies the trained model to new data to produce predictions.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            Any: Model predictions, which could be probabilities, class labels,
            regression values, or other outputs depending on the model type.
        """
        ...

    def test(self, *args: Any, **kwargs: Any) -> Any:
        """
        Evaluate a trained model on test data.

        This method assesses model performance on a test dataset, typically
        calculating metrics such as accuracy, loss, or other relevant measures.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            Any: Test results, which may include metrics, predictions,
            or other evaluation information.
        """
        ...

    def validate(self, *args: Any, **kwargs: Any) -> Any:
        """
        Evaluate a trained model on validation data.

        This method assesses model performance on a validation dataset, which is
        typically used during the training process for hyperparameter tuning
        or early stopping.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            Any: Validation results, which may include metrics, predictions,
            or other evaluation information.
        """
        ...


class SingletonResolver:
    """
    A singleton class that resolves and manages training components.

    This class serves as a central registry for training-related objects such as
    trainers, data, models, and checkpoint paths. It ensures that these components
    are accessible throughout the application while maintaining a single instance.

    Example:
        >>> resolver = SingletonResolver()
        >>> resolver.trainer = MyTrainer()
        >>> resolver.model = MyModel()
        >>> # Access the same instance elsewhere
        >>> same_resolver = SingletonResolver()
        >>> assert same_resolver.trainer is resolver.trainer
    """

    trainer: Trainer
    """The trainer instance responsible for executing the training process."""

    data: Any | None = None
    """The dataset or data loader used for training and evaluation.
    Defaults to None."""

    model: Any | None = None
    """The model architecture to be trained or evaluated.
    Defaults to None."""

    ckpt_path: Any | None = None
    """Path to checkpoint files for model loading/saving.
    Defaults to None."""

    return_predictions: Any | None = None
    """Flag or configuration for returning predictions.
    Defaults to None."""

    tuners: Any | None = None
    """Hyperparameter tuners or optimization components.
    Defaults to None."""


def register_singleton_resolver() -> Any:
    resolver = SingletonResolver()

    def resolve(key: str) -> Any:
        res = resolver
        for attr in key.split("."):
            res = getattr(res, attr)
        return res

    OmegaConf.register_new_resolver("fkat", resolve)
    return resolver


def to_str(cfg: Any) -> str:
    """
    Convert a configuration object to a formatted string representation.

    This function takes a configuration object and converts it to a human-readable
    YAML string. It's useful for logging, debugging, or displaying configuration settings.

    Args:
        cfg (Any): The configuration object to convert to string.

    Returns:
        str: A formatted string representation of the configuration.

    Example:
        >>> config = {"model": {"type": "resnet", "layers": 50}, "batch_size": 32}
        >>> print(to_str(config))
        Config:
        model:
          type: resnet
          layers: 50
        batch_size: 32
    """
    return "Config:\n" + OmegaConf.to_yaml(cfg)


def to_primitive_container(cfg: Any) -> Any:
    if OmegaConf.is_config(cfg):
        return OmegaConf.to_container(cfg)
    return cfg


@rank_zero_only
def save(cfg: Any, trainer: LightningTrainer) -> None:
    yaml_str = OmegaConf.to_yaml(cfg)
    with TemporaryDirectory() as temp_dir:
        yaml_path = os.path.join(temp_dir, "config.yaml")
        os.makedirs(os.path.dirname(yaml_path), exist_ok=True)
        with open(yaml_path, "w") as f:
            f.write(yaml_str)
        if mlflow := mlflow_logger(trainer):
            if mlflow.run_id:
                mlflow.experiment.log_artifact(mlflow.run_id, yaml_path)