fkat-0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fkat/__init__.py +147 -0
- fkat/data/__init__.py +15 -0
- fkat/data/data_module.py +198 -0
- fkat/data/datasets/__init__.py +19 -0
- fkat/data/datasets/dict.py +78 -0
- fkat/data/datasets/json.py +176 -0
- fkat/data/datasets/map.py +90 -0
- fkat/data/datasets/parquet.py +242 -0
- fkat/data/datasets/sized.py +31 -0
- fkat/data/dict.py +42 -0
- fkat/data/samplers/__init__.py +9 -0
- fkat/data/samplers/dict.py +38 -0
- fkat/data/samplers/sized.py +16 -0
- fkat/data/samplers/strategies.py +68 -0
- fkat/data/sharded.py +718 -0
- fkat/data/shm.py +364 -0
- fkat/predict.py +32 -0
- fkat/py.typed +0 -0
- fkat/pytorch/__init__.py +3 -0
- fkat/pytorch/actions/__init__.py +11 -0
- fkat/pytorch/actions/aws/__init__.py +3 -0
- fkat/pytorch/actions/aws/batch.py +29 -0
- fkat/pytorch/actions/aws/ec2.py +61 -0
- fkat/pytorch/callbacks/__init__.py +2 -0
- fkat/pytorch/callbacks/cuda/__init__.py +16 -0
- fkat/pytorch/callbacks/cuda/cache.py +115 -0
- fkat/pytorch/callbacks/cuda/memory.py +200 -0
- fkat/pytorch/callbacks/cuda/nsys.py +199 -0
- fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
- fkat/pytorch/callbacks/cuda/xid.py +173 -0
- fkat/pytorch/callbacks/debugging/__init__.py +9 -0
- fkat/pytorch/callbacks/debugging/introspection.py +569 -0
- fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
- fkat/pytorch/callbacks/gc.py +146 -0
- fkat/pytorch/callbacks/loggers.py +211 -0
- fkat/pytorch/callbacks/logging/__init__.py +12 -0
- fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
- fkat/pytorch/callbacks/logging/throughput.py +253 -0
- fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
- fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
- fkat/pytorch/callbacks/monitoring/crash.py +162 -0
- fkat/pytorch/callbacks/monitoring/dp.py +130 -0
- fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
- fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
- fkat/pytorch/callbacks/profiling/__init__.py +13 -0
- fkat/pytorch/callbacks/profiling/flops.py +574 -0
- fkat/pytorch/callbacks/profiling/memray.py +212 -0
- fkat/pytorch/callbacks/profiling/torch.py +197 -0
- fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
- fkat/pytorch/loggers.py +284 -0
- fkat/pytorch/schedule/__init__.py +27 -0
- fkat/pytorch/schedule/base.py +308 -0
- fkat/pytorch/schedule/mlflow.py +143 -0
- fkat/pytorch/utilities.py +49 -0
- fkat/test.py +31 -0
- fkat/train.py +32 -0
- fkat/utils/__init__.py +28 -0
- fkat/utils/aws/__init__.py +3 -0
- fkat/utils/aws/imds.py +137 -0
- fkat/utils/boto3.py +24 -0
- fkat/utils/config.py +194 -0
- fkat/utils/cuda/__init__.py +3 -0
- fkat/utils/cuda/preflight/__init__.py +3 -0
- fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
- fkat/utils/cuda/preflight/health_check/constants.py +23 -0
- fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
- fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
- fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
- fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
- fkat/utils/cuda/preflight/health_check/logger.py +205 -0
- fkat/utils/cuda/preflight/health_check/timer.py +31 -0
- fkat/utils/cuda/preflight/run.py +560 -0
- fkat/utils/cuda/xid.py +48 -0
- fkat/utils/logging.py +28 -0
- fkat/utils/mlflow.py +33 -0
- fkat/utils/pandas.py +25 -0
- fkat/utils/pdb.py +84 -0
- fkat/utils/pool.py +81 -0
- fkat/utils/profiler.py +18 -0
- fkat/utils/pyarrow.py +21 -0
- fkat/utils/rng.py +27 -0
- fkat/utils/shm.py +184 -0
- fkat/validate.py +31 -0
- fkat-0.1.2.dist-info/METADATA +134 -0
- fkat-0.1.2.dist-info/RECORD +88 -0
- fkat-0.1.2.dist-info/WHEEL +4 -0
- fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
- fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/utils/cuda/preflight/health_check/helpers.py
@@ -0,0 +1,297 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import hashlib
import os
import random
import string
from datetime import datetime, timezone
from typing import Any
from collections.abc import Callable

import mlflow
import requests
from requests.exceptions import RequestException
import torch
from dataclasses import dataclass
import time
import logging
from fkat.utils.aws.imds import InstanceMetadata
from fkat.utils.cuda.preflight.health_check.constants import (
    AWS_BATCH_JOB_ID,
    HEALTH_CHECK_TIMEOUT_SECS,
)
from pynvml import (
    NVMLError,
    nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetSerial,
    nvmlDeviceGetUUID,
    nvmlInit,
    nvmlShutdown,
)
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.multiprocessing import Queue

torch.manual_seed(12345)


@dataclass
class UniqueID:
    rank: int
    world_size: int
    local_rank: int
    num_nodes: int
    node_rank: int
    gpu_per_node: int
    gpu_hash_id: str
    master_addr: str


class InstanceStats:
    def __init__(self, instance_metadata: InstanceMetadata, gpu_info: dict[str | int, dict[str, Any]]) -> None:
        self.instance_id = instance_metadata.instance_id
        self.instance_type = instance_metadata.instance_type
        self.instance_ipv4 = instance_metadata.local_ipv4
        self.instance_hostname = instance_metadata.hostname
        self.instance_region = instance_metadata.region
        self.scan_datetime = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
        self.gpu_info = gpu_info
        self.healthy = True

    def upload_mlflow(self, instance_gpu_hash_id: str) -> None:
        """Upload instance stats to mlflow.

        Args:
            instance_gpu_hash_id (str): hash of the sorted instance GPU UUIDs
        """
        mlflow.log_param("instance_gpu_hash_id", instance_gpu_hash_id)  # type: ignore[possibly-unbound-attribute]
        mlflow.log_param("instance_id", self.instance_id)  # type: ignore[possibly-unbound-attribute]
        mlflow.log_param("instance_type", self.instance_type)  # type: ignore[possibly-unbound-attribute]
        mlflow.log_param("instance_ipv4", self.instance_ipv4)  # type: ignore[possibly-unbound-attribute]
        mlflow.log_param("instance_hostname", self.instance_hostname)  # type: ignore[possibly-unbound-attribute]

        # mlflow param has 300 limit, separate the gpu_info into serial_number and uuid
        mlflow.log_param("gpu_uuid", {k: v["uuid"] for k, v in self.gpu_info.items()})  # type: ignore[possibly-unbound-attribute]
        mlflow.log_param("gpu_serial_number", {k: v["serial"] for k, v in self.gpu_info.items()})  # type: ignore[possibly-unbound-attribute]


def make_requests(url: str, token: str) -> str:
    """Retrieve instance metadata from AWS EC2 Instance Metadata Service (IMDSv2).

    Args:
        url (str): The URL endpoint of the IMDSv2 metadata service.
        token (str): The authentication token required for IMDSv2 requests.

    Returns:
        str: The retrieved instance metadata as a string.

    Raises:
        requests.exceptions.RequestException: If the request fails due to network
            issues, an invalid URL, or a failed response status.
    """
    try:
        instance_info_response = requests.get(url, headers={"X-aws-ec2-metadata-token": token})
        instance_info_response.raise_for_status()
        instance_info = instance_info_response.text

    except RequestException as e:
        logging.error(f"Failed to retrieve instance metadata: {e}")
        raise e

    return instance_info


def generate_gpu_uuid_hash(uuid_list: list[str]) -> str:
    """
    Concatenates the sorted UUIDs, computes a SHA-256 hash,
    and returns the first 17 hex characters prefixed with "g-".
    """
    combined_uuid = "".join(sorted(uuid_list))
    sha_hash = hashlib.sha256(combined_uuid.encode("utf-8")).hexdigest()
    return "g-" + sha_hash[:17]  # Take the first 17 hex characters


def fetch_gpu_info() -> tuple[dict[int | str, dict[str, Any]], str]:
    """
    Retrieve GPU information from the current EC2 instance using the NVIDIA Management Library (NVML).

    This function initializes NVML to gather GPU details available to PyTorch,
    including GPU UUIDs and serial numbers. Additionally, it generates a hash ID
    representing all GPUs' UUIDs for easier identification.

    The function logs relevant information and gracefully handles errors, shutting
    down NVML in all scenarios.

    Returns:
        tuple:
            - gpu_info (dict): A dictionary containing GPU information where keys are
              PyTorch device indices (int) and values are dictionaries with the following keys:
                - 'uuid' (str): The UUID of the GPU.
                - 'serial' (str): The serial number of the GPU.

            - instance_gpu_hash_id (str): A hash string representing the combined UUIDs of all GPUs.

    Raises:
        NVMLError: If there's an issue retrieving GPU information from NVML.
    """
    gpu_info: dict[int | str, dict[str, Any]] = {}
    instance_gpu_hash_id = ""
    try:
        nvmlInit()

        # Get number of GPUs available to PyTorch
        num_gpus = torch.cuda.device_count()

        logging.info(f"Total of {num_gpus} GPUs exist on the device")

        # Iterate over each GPU by its PyTorch index
        for i in range(num_gpus):
            # Get GPU handle by index
            handle = nvmlDeviceGetHandleByIndex(i)
            # Retrieve the UUID and serial number of the GPU
            uuid = nvmlDeviceGetUUID(handle).decode("utf-8")
            serial = nvmlDeviceGetSerial(handle).decode("utf-8")
            logging.info(f"PyTorch Device Index {i} - GPU UUID: {uuid} - Serial Number: {serial}")

            gpu_info[i] = {
                "uuid": uuid,
                "serial": serial,
            }

        instance_gpu_hash_id = generate_gpu_uuid_hash([gpu["uuid"] for gpu in gpu_info.values()])
        logging.info(f"instance_gpu_hash_id is {instance_gpu_hash_id}")

    except NVMLError as e:
        logging.info("error when fetching information of GPU")
        raise e
    finally:
        nvmlShutdown()

    return gpu_info, instance_gpu_hash_id


def generate_random_string(length: int) -> str:
    """
    Generate a random string of the specified length containing uppercase letters,
    lowercase letters, and digits.

    Args:
        length (int): The desired length of the generated string.

    Returns:
        str: A randomly generated string of the specified length.
    """
    characters = string.ascii_letters + string.digits  # Uppercase, lowercase letters, and digits
    return "".join(random.choice(characters) for _ in range(length))


def generate_test_folder_name() -> str:
    """
    Generate a unique test folder name for the current job.

    When running under AWS Batch, the folder name is the stripped AWS Batch job ID.
    Otherwise it falls back to a local name built from the current UTC timestamp.

    Returns:
        str: A unique test folder name.

    Example:
        >>> generate_test_folder_name()
        'local_2025-03-24-15-30-45'
    """
    if AWS_BATCH_JOB_ID in os.environ:
        test_folder_name = strip_aws_batch_id(os.getenv(AWS_BATCH_JOB_ID, "local"))
    else:
        test_folder_name = f"local_{datetime.now(timezone.utc).strftime('%Y-%m-%d-%H-%M-%S')}"
    logging.info(f"current job log will be saved to folder {test_folder_name}")
    return test_folder_name


def strip_aws_batch_id(aws_batch_id: str) -> str:
    """
    Strip the AWS Batch ID to remove any additional node information.

    Args:
        aws_batch_id (str): The original AWS Batch ID, which may include a node index suffix.

    Returns:
        str: The stripped AWS Batch ID without any node index suffix.
    """
    return aws_batch_id.split("#")[0]


def destroy_process_group_if_initialized() -> None:
    """
    Safely destroys the PyTorch distributed process group if it is initialized.

    This function checks if the `torch.distributed` process group is both available
    and initialized. If so, it calls `destroy_process_group()` and logs success.
    Otherwise, it logs a warning. Any exceptions during the process are caught
    and logged as errors.
    """
    try:
        if dist.is_available() and dist.is_initialized():  # type: ignore[possibly-unbound-attribute]
            dist.destroy_process_group()  # type: ignore[possibly-unbound-attribute]
            logging.info("Process group destroyed.")
        else:
            logging.warning("No process group to destroy.")
    except Exception as e:
        logging.error(str(e))
        logging.error("Process group can't be terminated.")


def checkfunction_timeout_manager(func: Callable[..., None], kwargs: dict[str, Any]) -> Any:
    """
    Monitor and enforce a timeout for executing a function within a separate process.

    This function runs a specified function (`func`) in a separate process with
    the provided arguments (`kwargs`). It continuously monitors the execution time
    and terminates the process if it exceeds a defined timeout (`HEALTH_CHECK_TIMEOUT_SECS`).

    The function result is returned via a multiprocessing queue. If the timeout is reached,
    a `TimeoutError` is raised.

    Args:
        func (Callable): The target function to be executed in a separate process.
            It must accept `mlflow_run_id` and `result_queue` as its first
            two arguments, followed by additional `kwargs`.
        kwargs (dict): The keyword arguments to be passed to the function being monitored.

    Returns:
        Any: The result returned by `func` via the multiprocessing queue.

    Raises:
        TimeoutError: If the function exceeds the allowed timeout (`HEALTH_CHECK_TIMEOUT_SECS`).
    """
    active_run = mlflow.active_run()
    if active_run:
        mlflow_run_id = active_run.info.run_id
    else:
        raise Exception("mlflow is not activated.")

    result_queue: Queue[dict[str, Any] | list[int]] = mp.Queue()
    func_process = mp.Process(target=func, args=(mlflow_run_id, result_queue), kwargs=kwargs)
    func_process.start()

    try:
        start_time = time.perf_counter()
        while func_process.is_alive():
            elapsed = time.perf_counter() - start_time
            if elapsed > HEALTH_CHECK_TIMEOUT_SECS:
                logging.error(
                    f"[watchdog] func {func.__name__} Timeout reached ({HEALTH_CHECK_TIMEOUT_SECS} seconds). "  # type: ignore[unresolved-attribute]
                    "Terminating processes."
                )

                # Terminate the processes
                func_process.terminate()

                raise TimeoutError(f"[watchdog] func {func.__name__} exceeded {HEALTH_CHECK_TIMEOUT_SECS:.3f} seconds")  # type: ignore[unresolved-attribute]
            time.sleep(1)

        # If func finished before timeout, clean up the gpu_log process.
        logging.info(f"[watchdog] func {func.__name__} completed in time, used {elapsed}s; terminating gpu_log.")  # type: ignore[unresolved-attribute]
        return result_queue.get()
    finally:
        func_process.terminate()
        func_process.join()
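
For orientation, checkfunction_timeout_manager expects an already-active MLflow run and a target function whose first two parameters are mlflow_run_id and result_queue, with the result reported back through the queue. A minimal usage sketch follows; the dummy_gpu_check target and its size parameter are illustrative stand-ins, not part of the package.

import mlflow
import torch
from torch.multiprocessing import Queue

from fkat.utils.cuda.preflight.health_check.helpers import checkfunction_timeout_manager


def dummy_gpu_check(mlflow_run_id: str, result_queue: Queue, size: int = 1024) -> None:
    # Illustrative health-check target: run a small matmul on each visible GPU
    # and send the list of GPUs that completed back through the queue.
    completed = []
    for i in range(torch.cuda.device_count()):
        x = torch.randn(size, size, device=f"cuda:{i}")
        torch.matmul(x, x)
        completed.append(i)
    result_queue.put(completed)


if __name__ == "__main__":
    with mlflow.start_run():  # the watchdog reads the active run's id
        result = checkfunction_timeout_manager(dummy_gpu_check, {"size": 2048})
        print(f"GPUs that passed the check: {result}")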
fkat/utils/cuda/preflight/health_check/logger.py
@@ -0,0 +1,205 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
import socket
import time

import mlflow
import torch
import logging
from fkat.utils.cuda.preflight.health_check.constants import (
    AWS_BATCH_JOB_ID,
    AWS_BATCH_LINK,
    MLFLOW_EXPERIMENT_NAME,
)
from fkat.utils.cuda.preflight.health_check.helpers import (
    InstanceStats,
    UniqueID,
    generate_random_string,
    strip_aws_batch_id,
)


def search_join_mlflow_run(run_name: str) -> None:
    """
    Searches for the most recent active MLflow run with the specified run name and joins it.

    This function looks for an active MLflow run matching the given `run_name` within the current region's
    configured experiment. If a match is found, it starts logging to that run. If no run is found,
    it raises a RuntimeError.

    Args:
        run_name (str): The name of the MLflow run to search for.

    Returns:
        None: The matched run is joined in place via `mlflow.start_run`.

    Raises:
        RuntimeError: If no active MLflow run with the specified name is found.
    """
    runs = mlflow.search_runs(  # type: ignore[possibly-unbound-attribute]
        experiment_names=[MLFLOW_EXPERIMENT_NAME.format(region=os.environ["AWS_DEFAULT_REGION"])],
        # Output format should be a list.
        output_format="list",
        # Only searching for active runs.
        run_view_type=1,
        max_results=1,
        order_by=["start_time DESC"],
        filter_string=f"attributes.run_name = '{run_name}'",
    )

    if runs:
        logging.info(f"Found parent mlflow run: {runs[0].info.run_id}")
        latest_run_id = runs[0].info.run_id
        mlflow.start_run(latest_run_id)  # type: ignore[possibly-unbound-attribute]
    else:
        raise RuntimeError("Can't find parent mlflow runs.")


def create_job_level_mlflow_run(job_level_mlflow_run_name: str, instance_stats: InstanceStats) -> None:
    """Create the job-level mlflow run: batch_id if a batch job, local if a local job.

    It is created only once per job, by the rank 0 process. All other processes wait for 5s.
    """
    # global_rank 0 process initializes the run on mlflow, indexed by aws_batch_id
    if int(os.environ["RANK"]) == 0:
        mlflow.start_run(run_name=job_level_mlflow_run_name)  # type: ignore[possibly-unbound-attribute]

        mlflow.log_param("instance_type", instance_stats.instance_type)  # type: ignore[possibly-unbound-attribute]
        mlflow.log_param("scan_datetime", instance_stats.scan_datetime)  # type: ignore[possibly-unbound-attribute]
        mlflow.log_param("region", instance_stats.instance_region)  # type: ignore[possibly-unbound-attribute]
        mlflow.log_param("batch_job_id", os.getenv(AWS_BATCH_JOB_ID, "local"))  # type: ignore[possibly-unbound-attribute]
        mlflow.log_param(  # type: ignore[possibly-unbound-attribute]
            "batch_job_link",
            (
                AWS_BATCH_LINK.format(
                    region=instance_stats.instance_region,
                    batch_id=os.environ[AWS_BATCH_JOB_ID],
                )
                if AWS_BATCH_JOB_ID in os.environ
                else "local"
            ),
        )
    else:
        time.sleep(5)


def create_instance_level_mlflow_run(
    unique_id: UniqueID,
    job_level_mlflow_run_name: str,
    instance_stats: InstanceStats,
) -> None:
    """
    Creates an instance-level (nested) MLflow run and logs instance metadata.

    This function is called on every process, but only the local rank 0 process of each node
    creates the nested run. That process joins the job-level run (searching for it when it is
    not the global rank 0 process), starts a nested run named after the instance GPU hash ID,
    and logs instance metadata such as instance ID, node rank, and per-GPU UUID/serial.
    All non-zero local-rank processes wait for 5 seconds to ensure the run is created before
    proceeding.

    Args:
        unique_id (UniqueID): Rank/topology identifiers for the current process, including the GPU hash ID.
        job_level_mlflow_run_name (str): The name of the job-level MLflow run to join.
        instance_stats (InstanceStats): An object containing metadata about the instance,
            including type, region, and scan timestamp.

    Returns:
        None
    """
    # Only the first process of each instance can join the mlflow run
    if os.environ["LOCAL_RANK"] == "0":
        # If global_rank is 0, we don't need to search for it in mlflow.
        logging.info(f"Start the instance_level layer mlflow run in node {instance_stats}.")
        if os.environ["RANK"] != "0":
            search_join_mlflow_run(run_name=job_level_mlflow_run_name)

        mlflow.log_param(f"instance_addr_node_{unique_id.node_rank}", socket.gethostname())  # type: ignore[possibly-unbound-attribute]

        mlflow.start_run(run_name=unique_id.gpu_hash_id, nested=True, log_system_metrics=True)  # type: ignore[possibly-unbound-attribute]
        mlflow.set_system_metrics_node_id(unique_id.node_rank)  # type: ignore[possibly-unbound-attribute]

        mlflow.log_param("instance_id", instance_stats.instance_id)  # type: ignore[possibly-unbound-attribute]
        mlflow.log_param("instance_gpu_hash_id", unique_id.gpu_hash_id)  # type: ignore[possibly-unbound-attribute]
        mlflow.log_param("node_rank", unique_id.node_rank)  # type: ignore[possibly-unbound-attribute]
        mlflow.log_param(  # type: ignore[possibly-unbound-attribute]
            f"gpu_uuid_gpu_{os.environ['LOCAL_RANK']}",
            instance_stats.gpu_info[torch.cuda.current_device()]["uuid"],
        )
        mlflow.log_param(  # type: ignore[possibly-unbound-attribute]
            f"gpu_serial_gpu_{os.environ['LOCAL_RANK']}",
            instance_stats.gpu_info[torch.cuda.current_device()]["serial"],
        )
        mlflow.log_param(f"global_rank_gpu_{os.environ['LOCAL_RANK']}", os.environ["RANK"])  # type: ignore[possibly-unbound-attribute]
        instance_stats.upload_mlflow(unique_id.gpu_hash_id)
    else:
        time.sleep(5)


def get_parent_mlflow_id() -> str:
    """
    Return the MLflow run ID of the job-level (parent) run.

    Assumes the instance-level (nested) run is currently active: it resolves the active run,
    looks up that run's parent, and returns the parent's run ID.

    Returns:
        str: The MLflow run ID of the job-level (parent) run.

    Raises:
        Exception: If no MLflow run is active, or the active run has no parent run.
    """
    if active_run := mlflow.active_run():
        current_run_id = active_run.info.run_id
    else:
        raise Exception("mlflow is not activated.")

    if parent_run := mlflow.get_parent_run(current_run_id):  # type: ignore[possibly-unbound-attribute]
        parent_run_id = parent_run.info.run_id
    else:
        raise Exception("instance level mlflow run should have a parent run.")

    return parent_run_id


def initialize_mlflow(unique_id: UniqueID, instance_stats: InstanceStats) -> str:
    """Initialize mlflow. The MLflow run will have 2 layers, indexed by the following:

    1. batch_run_id or "local_********"
    2. instance_gpu_hash_id

    In this way metrics/parameters/artifacts can be better organized.
    """
    # Set up mlflow endpoint and experiment by region
    logging.info("Initializing MLflow")

    mlflow.set_tracking_uri(uri=os.environ["MLFLOW_URI"])
    mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME.format(region=instance_stats.instance_region))

    logging.info("Start the job_level layer mlflow run")

    job_level_mlflow_run_name = strip_aws_batch_id(os.getenv(AWS_BATCH_JOB_ID, f"local-{generate_random_string(10)}"))

    # global_rank 0 process initializes the job-level run on mlflow, named by aws_batch_id
    create_job_level_mlflow_run(job_level_mlflow_run_name, instance_stats)

    # local_rank 0 process initializes the instance-level run on mlflow, named by gpu_hash_id
    create_instance_level_mlflow_run(unique_id, job_level_mlflow_run_name, instance_stats)

    # other processes join the instance-level run
    if os.environ["LOCAL_RANK"] != "0":
        search_join_mlflow_run(run_name=unique_id.gpu_hash_id)

    # return the job-level mlflow run id
    return get_parent_mlflow_id()


def end_all_mlflow_active_runs() -> None:
    """End all active mlflow runs."""
    while active_run := mlflow.active_run():
        logging.info(f"Ending run: {active_run.info.run_id}")
        mlflow.end_run()  # type: ignore[possibly-unbound-attribute]
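
Putting helpers.py and this module together, the apparent call pattern is: collect GPU info, build InstanceStats and UniqueID, initialize the two-layer run hierarchy, and end every nested run on exit. Below is a hedged sketch under torchrun-style environment variables (RANK, LOCAL_RANK, WORLD_SIZE); it assumes this hunk is health_check/logger.py (per the +205 line count in the file list). How InstanceMetadata is actually fetched belongs to run.py and is left as a parameter, and the NNODES/NODE_RANK fallbacks are assumptions.

import os

from fkat.utils.aws.imds import InstanceMetadata
from fkat.utils.cuda.preflight.health_check.helpers import InstanceStats, UniqueID, fetch_gpu_info
from fkat.utils.cuda.preflight.health_check.logger import end_all_mlflow_active_runs, initialize_mlflow


def preflight_mlflow_setup(instance_metadata: InstanceMetadata) -> str:
    # Query NVML for per-GPU UUID/serial and the instance-wide GPU hash id.
    gpu_info, gpu_hash_id = fetch_gpu_info()
    instance_stats = InstanceStats(instance_metadata, gpu_info)

    # Topology of the current process, read from torchrun-style env vars.
    unique_id = UniqueID(
        rank=int(os.environ["RANK"]),
        world_size=int(os.environ["WORLD_SIZE"]),
        local_rank=int(os.environ["LOCAL_RANK"]),
        num_nodes=int(os.environ.get("NNODES", "1")),      # assumed env var
        node_rank=int(os.environ.get("NODE_RANK", "0")),   # assumed env var
        gpu_per_node=int(os.environ.get("LOCAL_WORLD_SIZE", "1")),
        gpu_hash_id=gpu_hash_id,
        master_addr=os.environ.get("MASTER_ADDR", "localhost"),
    )

    # Requires MLFLOW_URI and AWS_DEFAULT_REGION; returns the job-level run id.
    return initialize_mlflow(unique_id, instance_stats)


# At the end of the preflight job, close the nested and parent runs:
# end_all_mlflow_active_runs()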
fkat/utils/cuda/preflight/health_check/timer.py
@@ -0,0 +1,31 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import time
from contextlib import contextmanager
from collections.abc import Iterator, Mapping


class Timer(Mapping[str, float]):
    def __init__(self) -> None:
        self._times: dict[str, list[float]] = {}

    @contextmanager
    def __call__(self, name: str) -> Iterator[None]:
        start = time.perf_counter()
        try:
            yield
        finally:
            end = time.perf_counter()
            self._times.setdefault(name, []).append(1000 * (end - start))

    def __getitem__(self, name: str) -> float:
        if len(self._times[name]) == 1:
            return self._times[name][0]
        else:
            return max(self._times[name])

    def __iter__(self) -> Iterator[str]:
        return iter(self._times)

    def __len__(self) -> int:
        return len(self._times)
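
This class is a Mapping keyed by section name: each "with timer(name):" block appends the elapsed time in milliseconds, and indexing returns the single measurement or the maximum across repeats. A brief illustration, assuming this hunk is health_check/timer.py (the section names below are arbitrary):

import time

from fkat.utils.cuda.preflight.health_check.timer import Timer

timer = Timer()

with timer("warmup"):
    time.sleep(0.01)

for _ in range(3):
    with timer("stress_round"):
        time.sleep(0.005)

# Mapping interface: one entry per section, values in milliseconds
for name in timer:
    print(f"{name}: {timer[name]:.2f} ms (max across repeated sections)")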