fkat 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/utils/cuda/preflight/health_check/helpers.py
@@ -0,0 +1,297 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import hashlib
+ import os
+ import random
+ import string
+ from datetime import datetime, timezone
+ from typing import Any
+ from collections.abc import Callable
+
+ import mlflow
+ import requests
+ from requests.exceptions import RequestException
+ import torch
+ from dataclasses import dataclass
+ import time
+ import logging
+ from fkat.utils.aws.imds import InstanceMetadata
+ from fkat.utils.cuda.preflight.health_check.constants import (
+     AWS_BATCH_JOB_ID,
+     HEALTH_CHECK_TIMEOUT_SECS,
+ )
+ from pynvml import (
+     NVMLError,
+     nvmlDeviceGetHandleByIndex,
+     nvmlDeviceGetSerial,
+     nvmlDeviceGetUUID,
+     nvmlInit,
+     nvmlShutdown,
+ )
+ import torch.multiprocessing as mp
+ import torch.distributed as dist
+ from torch.multiprocessing import Queue
+
+ torch.manual_seed(12345)
+
+
+ @dataclass
+ class UniqueID:
+     rank: int
+     world_size: int
+     local_rank: int
+     num_nodes: int
+     node_rank: int
+     gpu_per_node: int
+     gpu_hash_id: str
+     master_addr: str
+
+
+ class InstanceStats:
+     def __init__(self, instance_metadata: InstanceMetadata, gpu_info: dict[str | int, dict[str, Any]]) -> None:
+         self.instance_id = instance_metadata.instance_id
+         self.instance_type = instance_metadata.instance_type
+         self.instance_ipv4 = instance_metadata.local_ipv4
+         self.instance_hostname = instance_metadata.hostname
+         self.instance_region = instance_metadata.region
+         self.scan_datetime = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
+         self.gpu_info = gpu_info
+         self.healthy = True
+
+     def upload_mlflow(self, instance_gpu_hash_id: str) -> None:
+         """Upload instance stats to mlflow.
+
+         Args:
+             instance_gpu_hash_id (str): hash of the sorted instance GPU UUIDs
+         """
+         mlflow.log_param("instance_gpu_hash_id", instance_gpu_hash_id)  # type: ignore[possibly-unbound-attribute]
+         mlflow.log_param("instance_id", self.instance_id)  # type: ignore[possibly-unbound-attribute]
+         mlflow.log_param("instance_type", self.instance_type)  # type: ignore[possibly-unbound-attribute]
+         mlflow.log_param("instance_ipv4", self.instance_ipv4)  # type: ignore[possibly-unbound-attribute]
+         mlflow.log_param("instance_hostname", self.instance_hostname)  # type: ignore[possibly-unbound-attribute]
+
+         # mlflow params have a 300-character limit, so separate gpu_info into serial_number and uuid
+         mlflow.log_param("gpu_uuid", {k: v["uuid"] for k, v in self.gpu_info.items()})  # type: ignore[possibly-unbound-attribute]
+         mlflow.log_param("gpu_serial_number", {k: v["serial"] for k, v in self.gpu_info.items()})  # type: ignore[possibly-unbound-attribute]
+
+
+ def make_requests(url: str, token: str) -> str:
+     """Retrieve instance metadata from the AWS EC2 Instance Metadata Service (IMDSv2).
+
+     Args:
+         url (str): The URL endpoint of the IMDSv2 metadata service.
+         token (str): The authentication token required for IMDSv2 requests.
+
+     Returns:
+         str: The retrieved instance metadata as a string.
+
+     Raises:
+         requests.exceptions.RequestException: If the request fails due to network
+             issues, an invalid URL, or a failed response status.
+     """
+     try:
+         instance_info_response = requests.get(url, headers={"X-aws-ec2-metadata-token": token})
+         instance_info_response.raise_for_status()
+         instance_info = instance_info_response.text
+
+     except RequestException as e:
+         logging.error(f"Failed to retrieve instance metadata: {e}")
+         raise e
+
+     return instance_info
+
+
+ def generate_gpu_uuid_hash(uuid_list: list[str]) -> str:
+     """
+     Concatenates the sorted UUIDs, computes a SHA-256 hash,
+     and returns the first 17 hex characters prefixed with "g-".
+     """
+     combined_uuid = "".join(sorted(uuid_list))
+     sha_hash = hashlib.sha256(combined_uuid.encode("utf-8")).hexdigest()
+     return "g-" + sha_hash[:17]  # Take the first 17 hex characters
+
+
+ def fetch_gpu_info() -> tuple[dict[int | str, dict[str, Any]], str]:
+     """
+     Retrieve GPU information from the current EC2 instance using the NVIDIA Management Library (NVML).
+
+     This function initializes NVML to gather details for the GPUs available to PyTorch,
+     including GPU UUIDs and serial numbers. Additionally, it generates a hash ID
+     representing all GPUs' UUIDs for easier identification.
+
+     The function logs relevant information and gracefully handles errors, shutting
+     down NVML in all scenarios.
+
+     Returns:
+         tuple:
+             - gpu_info (dict): A dictionary containing GPU information where keys are
+               PyTorch device indices (int) and values are dictionaries with the following keys:
+                 - 'uuid' (str): The UUID of the GPU.
+                 - 'serial' (str): The serial number of the GPU.
+
+             - instance_gpu_hash_id (str): A hash string representing the combined UUIDs of all GPUs.
+
+     Raises:
+         NVMLError: If there's an issue retrieving GPU information from NVML.
+     """
+     gpu_info: dict[int | str, dict[str, Any]] = {}
+     instance_gpu_hash_id = ""
+     try:
+         nvmlInit()
+
+         # Get the number of GPUs available to PyTorch
+         num_gpus = torch.cuda.device_count()
+
+         logging.info(f"Total of {num_gpus} GPUs found on this instance")
+
+         # Iterate over each GPU by its PyTorch index
+         for i in range(num_gpus):
+             # Get the GPU handle by index
+             handle = nvmlDeviceGetHandleByIndex(i)
+             # Retrieve the UUID and serial number of the GPU
+             uuid = nvmlDeviceGetUUID(handle).decode("utf-8")
+             serial = nvmlDeviceGetSerial(handle).decode("utf-8")
+             logging.info(f"PyTorch Device Index {i} - GPU UUID: {uuid} - Serial Number: {serial}")
+
+             gpu_info[i] = {
+                 "uuid": uuid,
+                 "serial": serial,
+             }
+
+         instance_gpu_hash_id = generate_gpu_uuid_hash([gpu["uuid"] for gpu in gpu_info.values()])
+         logging.info(f"instance_gpu_hash_id is {instance_gpu_hash_id}")
+
+     except NVMLError as e:
+         logging.error("Error when fetching GPU information")
+         raise e
+     finally:
+         nvmlShutdown()
+
+     return gpu_info, instance_gpu_hash_id
+
+
+ def generate_random_string(length: int) -> str:
+     """
+     Generate a random string of the specified length containing uppercase letters,
+     lowercase letters, and digits.
+
+     Args:
+         length (int): The desired length of the generated string.
+
+     Returns:
+         str: A randomly generated string of the specified length.
+     """
+     characters = string.ascii_letters + string.digits  # Uppercase letters, lowercase letters, and digits
+     return "".join(random.choice(characters) for _ in range(length))
+
+
+ def generate_test_folder_name() -> str:
+     """
+     Generate the folder name under which the current job's logs will be saved.
+
+     If the job runs on AWS Batch, the name is the AWS Batch job ID stripped of its
+     node-index suffix. Otherwise, a local name is built from the current UTC timestamp.
+
+     Returns:
+         str: The test folder name.
+
+     Example:
+         >>> generate_test_folder_name()
+         'local_2025-03-24-15-30-45'
+     """
+     if AWS_BATCH_JOB_ID in os.environ:
+         test_folder_name = strip_aws_batch_id(os.getenv(AWS_BATCH_JOB_ID, "local"))
+     else:
+         test_folder_name = f"local_{datetime.now(timezone.utc).strftime('%Y-%m-%d-%H-%M-%S')}"
+     logging.info(f"current job log will be saved to folder {test_folder_name}")
+     return test_folder_name
+
+
+ def strip_aws_batch_id(aws_batch_id: str) -> str:
+     """
+     Strip the AWS Batch ID to remove any additional node information.
+
+     Args:
+         aws_batch_id (str): The original AWS Batch ID, which may include a node index suffix.
+
+     Returns:
+         str: The stripped AWS Batch ID without any node index suffix.
+     """
+     return aws_batch_id.split("#")[0]
+
+
+ def destroy_process_group_if_initialized() -> None:
+     """
+     Safely destroys the PyTorch distributed process group if it is initialized.
+
+     This function checks if the `torch.distributed` process group is both available
+     and initialized. If so, it calls `destroy_process_group()` and logs success.
+     Otherwise, it logs a warning. Any exceptions during the process are caught
+     and logged as errors.
+     """
+     try:
+         if dist.is_available() and dist.is_initialized():  # type: ignore[possibly-unbound-attribute]
+             dist.destroy_process_group()  # type: ignore[possibly-unbound-attribute]
+             logging.info("Process group destroyed.")
+         else:
+             logging.warning("No process group to destroy.")
+     except Exception as e:
+         logging.error(str(e))
+         logging.error("Process group can't be terminated.")
+
+
+ def checkfunction_timeout_manager(func: Callable[..., None], kwargs: dict[str, Any]) -> Any:
+     """
+     Monitor and enforce a timeout for executing a function within a separate process.
+
+     This function runs a specified function (`func`) in a separate process with
+     the provided arguments (`kwargs`). It continuously monitors the execution time
+     and terminates the process if it exceeds a defined timeout (`HEALTH_CHECK_TIMEOUT_SECS`).
+
+     The function result is returned via a multiprocessing queue. If the timeout is reached,
+     a `TimeoutError` is raised.
+
+     Args:
+         func (Callable): The target function to be executed in a separate process.
+             It must accept `mlflow_run_id` and `result_queue` as its first
+             two arguments, followed by additional `kwargs`.
+         kwargs (dict): The keyword arguments to be passed to the function being monitored.
+
+     Returns:
+         Any: The result returned by `func` via the multiprocessing queue.
+
+     Raises:
+         TimeoutError: If the function exceeds the allowed timeout (`HEALTH_CHECK_TIMEOUT_SECS`).
+     """
+     active_run = mlflow.active_run()
+     if active_run:
+         mlflow_run_id = active_run.info.run_id
+     else:
+         raise Exception("mlflow is not activated.")
+
+     result_queue: Queue[dict[str, Any] | list[int]] = mp.Queue()
+     func_process = mp.Process(target=func, args=(mlflow_run_id, result_queue), kwargs=kwargs)
+     func_process.start()
+
+     try:
+         start_time = time.perf_counter()
+         while func_process.is_alive():
+             elapsed = time.perf_counter() - start_time
+             if elapsed > HEALTH_CHECK_TIMEOUT_SECS:
+                 logging.error(
+                     f"[watchdog] func {func.__name__} Timeout reached ({HEALTH_CHECK_TIMEOUT_SECS} seconds). "  # type: ignore[unresolved-attribute]
+                     "Terminating processes."
+                 )
+
+                 # Terminate the worker process
+                 func_process.terminate()
+
+                 raise TimeoutError(f"[watchdog] func {func.__name__} exceeded {HEALTH_CHECK_TIMEOUT_SECS:.3f} seconds")  # type: ignore[unresolved-attribute]
+             time.sleep(1)
+
+         # If func finished before timeout, report the runtime and return its result;
+         # the worker process is cleaned up in the finally block.
+         logging.info(f"[watchdog] func {func.__name__} completed in time, used {time.perf_counter() - start_time:.3f}s; terminating worker process.")  # type: ignore[unresolved-attribute]
+         return result_queue.get()
+     finally:
+         func_process.terminate()
+         func_process.join()
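The watchdog helper above expects the wrapped function to take `mlflow_run_id` and a result queue as its first two positional arguments and to report back through that queue, with the timeout taken from `HEALTH_CHECK_TIMEOUT_SECS`. A minimal sketch of how a caller might drive it follows; `burn_in_test` and its `duration_s` argument are hypothetical placeholders, not part of the package.

# Hypothetical sketch of driving a health-check worker through checkfunction_timeout_manager.
from typing import Any

import mlflow
from torch.multiprocessing import Queue

from fkat.utils.cuda.preflight.health_check.helpers import checkfunction_timeout_manager


def burn_in_test(mlflow_run_id: str, result_queue: "Queue[dict[str, Any] | list[int]]", *, duration_s: int) -> None:
    # Toy worker: it receives the run id and queue first, then reports its result via the queue.
    result_queue.put({"mlflow_run_id": mlflow_run_id, "duration_s": duration_s, "status": "ok"})


if __name__ == "__main__":
    with mlflow.start_run():  # the watchdog requires an active MLflow run
        result = checkfunction_timeout_manager(burn_in_test, {"duration_s": 60})
        print(result)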
fkat/utils/cuda/preflight/health_check/logger.py
@@ -0,0 +1,205 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ import socket
+ import time
+
+ import mlflow
+ import torch
+ import logging
+ from fkat.utils.cuda.preflight.health_check.constants import (
+     AWS_BATCH_JOB_ID,
+     AWS_BATCH_LINK,
+     MLFLOW_EXPERIMENT_NAME,
+ )
+ from fkat.utils.cuda.preflight.health_check.helpers import (
+     InstanceStats,
+     UniqueID,
+     generate_random_string,
+     strip_aws_batch_id,
+ )
+
+
+ def search_join_mlflow_run(run_name: str) -> None:
+     """
+     Searches for the most recent active MLflow run with the specified run name and joins it.
+
+     This function looks for an active MLflow run matching the given `run_name` within the current region's
+     configured experiment. If a match is found, it starts logging to that run. If no run is found,
+     it raises a RuntimeError.
+
+     Args:
+         run_name (str): The name of the MLflow run to search for.
+
+     Returns:
+         None: The matched run is joined via `mlflow.start_run`.
+
+     Raises:
+         RuntimeError: If no active MLflow run with the specified name is found.
+     """
+     runs = mlflow.search_runs(  # type: ignore[possibly-unbound-attribute]
+         experiment_names=[MLFLOW_EXPERIMENT_NAME.format(region=os.environ["AWS_DEFAULT_REGION"])],
+         # Output format should be a list.
+         output_format="list",
+         # Only searching for active runs.
+         run_view_type=1,
+         max_results=1,
+         order_by=["start_time DESC"],
+         filter_string=f"attributes.run_name = '{run_name}'",
+     )
+
+     if runs:
+         logging.info(f"Found parent mlflow run: {runs[0].info.run_id}")
+         latest_run_id = runs[0].info.run_id
+         mlflow.start_run(latest_run_id)  # type: ignore[possibly-unbound-attribute]
+     else:
+         raise RuntimeError("Can't find parent mlflow runs.")
+
+
+ def create_job_level_mlflow_run(job_level_mlflow_run_name: str, instance_stats: InstanceStats) -> None:
+     """Create the job-level mlflow run, named by the batch ID for batch jobs or "local" for local jobs.
+     This run is only created once per job, by rank == 0. All other processes wait for 5s.
+     """
+     # The global_rank 0 process initializes the run on mlflow, indexed by aws_batch_id
+     if int(os.environ["RANK"]) == 0:
+         mlflow.start_run(run_name=job_level_mlflow_run_name)  # type: ignore[possibly-unbound-attribute]
+
+         mlflow.log_param("instance_type", instance_stats.instance_type)  # type: ignore[possibly-unbound-attribute]
+         mlflow.log_param("scan_datetime", instance_stats.scan_datetime)  # type: ignore[possibly-unbound-attribute]
+         mlflow.log_param("region", instance_stats.instance_region)  # type: ignore[possibly-unbound-attribute]
+         mlflow.log_param("batch_job_id", os.getenv(AWS_BATCH_JOB_ID, "local"))  # type: ignore[possibly-unbound-attribute]
+         mlflow.log_param(  # type: ignore[possibly-unbound-attribute]
+             "batch_job_link",
+             (
+                 AWS_BATCH_LINK.format(
+                     region=instance_stats.instance_region,
+                     batch_id=os.environ[AWS_BATCH_JOB_ID],
+                 )
+                 if AWS_BATCH_JOB_ID in os.environ
+                 else "local"
+             ),
+         )
+     else:
+         time.sleep(5)
+
+
+ def create_instance_level_mlflow_run(
+     unique_id: UniqueID,
+     job_level_mlflow_run_name: str,
+     instance_stats: InstanceStats,
+ ) -> None:
+     """
+     Creates an instance-level MLflow run nested under the job-level run and logs instance metadata.
+
+     This function is executed by the local rank 0 process on each node. Non-zero global ranks
+     first join the existing job-level run, then a nested run named after the instance's GPU hash ID
+     is started and instance metadata such as the instance ID, node rank, and per-GPU UUIDs and
+     serial numbers is logged. All non-zero local rank processes wait for 5 seconds to ensure the
+     run is created before proceeding.
+
+     Args:
+         unique_id (UniqueID): Rank and GPU identity information for the current process.
+         job_level_mlflow_run_name (str): The name of the parent job-level MLflow run.
+         instance_stats (InstanceStats): An object containing metadata about the instance,
+             including type, region, and scan timestamp.
+
+     Returns:
+         None
+     """
+     # Only the first process of each instance creates and joins the mlflow run
+     if os.environ["LOCAL_RANK"] == "0":
+         # If global_rank is 0, we don't need to search for the job-level run in mlflow.
+         logging.info(f"Starting the instance-level mlflow run on node {instance_stats}.")
+         if os.environ["RANK"] != "0":
+             search_join_mlflow_run(run_name=job_level_mlflow_run_name)
+
+         mlflow.log_param(f"instance_addr_node_{unique_id.node_rank}", socket.gethostname())  # type: ignore[possibly-unbound-attribute]
+
+         mlflow.start_run(run_name=unique_id.gpu_hash_id, nested=True, log_system_metrics=True)  # type: ignore[possibly-unbound-attribute]
+         mlflow.set_system_metrics_node_id(unique_id.node_rank)  # type: ignore[possibly-unbound-attribute]
+
+         mlflow.log_param("instance_id", instance_stats.instance_id)  # type: ignore[possibly-unbound-attribute]
+         mlflow.log_param("instance_gpu_hash_id", unique_id.gpu_hash_id)  # type: ignore[possibly-unbound-attribute]
+         mlflow.log_param("node_rank", unique_id.node_rank)  # type: ignore[possibly-unbound-attribute]
+         mlflow.log_param(  # type: ignore[possibly-unbound-attribute]
+             f"gpu_uuid_gpu_{os.environ['LOCAL_RANK']}",
+             instance_stats.gpu_info[torch.cuda.current_device()]["uuid"],
+         )
+         mlflow.log_param(  # type: ignore[possibly-unbound-attribute]
+             f"gpu_serial_gpu_{os.environ['LOCAL_RANK']}",
+             instance_stats.gpu_info[torch.cuda.current_device()]["serial"],
+         )
+         mlflow.log_param(f"global_rank_gpu_{os.environ['LOCAL_RANK']}", os.environ["RANK"])  # type: ignore[possibly-unbound-attribute]
+         instance_stats.upload_mlflow(unique_id.gpu_hash_id)
+     else:
+         time.sleep(5)
+
+
+ def get_parent_mlflow_id() -> str:
+     """
+     Return the MLflow run ID of the job-level (parent) run.
+
+     This function assumes the current process has already joined an instance-level
+     (nested) MLflow run. It looks up the currently active run and returns the run ID
+     of its parent, i.e. the job-level run created during `initialize_mlflow`.
+
+     Returns:
+         str: The MLflow run ID of the job-level (parent) run.
+
+     Raises:
+         Exception: If there is no active MLflow run, or the active run has no parent run.
+     """
+     if active_run := mlflow.active_run():
+         current_run_id = active_run.info.run_id
+     else:
+         raise Exception("mlflow is not activated.")
+
+     if parent_run := mlflow.get_parent_run(current_run_id):  # type: ignore[possibly-unbound-attribute]
+         parent_run_id = parent_run.info.run_id
+     else:
+         raise Exception("instance level mlflow run should have a parent run.")
+
+     return parent_run_id
+
+
+ def initialize_mlflow(unique_id: UniqueID, instance_stats: InstanceStats) -> str:
+     """Initialize mlflow. The MLflow run will have 2 layers, indexed by the following:
+     1. batch_run_id or "local_********"
+     2. instance_gpu_hash_id.
+
+     This way, metrics/parameters/artifacts can be better organized.
+     """
+     # Set up the mlflow endpoint and experiment by region
+     logging.info("Initializing MLflow")
+
+     mlflow.set_tracking_uri(uri=os.environ["MLFLOW_URI"])
+     mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME.format(region=instance_stats.instance_region))
+
+     logging.info("Starting the job-level mlflow run")
+
+     job_level_mlflow_run_name = strip_aws_batch_id(os.getenv(AWS_BATCH_JOB_ID, f"local-{generate_random_string(10)}"))
+
+     # The global_rank 0 process initializes the job-level run on mlflow, named by aws_batch_id
+     create_job_level_mlflow_run(job_level_mlflow_run_name, instance_stats)
+
+     # The local_rank 0 processes initialize the instance-level runs on mlflow, named by gpu_hash_id
+     create_instance_level_mlflow_run(unique_id, job_level_mlflow_run_name, instance_stats)
+
+     # Other processes join the instance-level run
+     if os.environ["LOCAL_RANK"] != "0":
+         search_join_mlflow_run(run_name=unique_id.gpu_hash_id)
+
+     # Return the job-level mlflow run id
+     return get_parent_mlflow_id()
+
+
+ def end_all_mlflow_active_runs() -> None:
+     """End all active mlflow runs."""
+     while active_run := mlflow.active_run():
+         logging.info(f"Ending run: {active_run.info.run_id}")
+         mlflow.end_run()  # type: ignore[possibly-unbound-attribute]
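Taken together, these functions set up a two-layer run hierarchy: a job-level run named after the AWS Batch job ID (or a local fallback), with one nested instance-level run per node named after the GPU hash. The sketch below shows roughly how a preflight entry point might wire them up. The environment variables `RANK`, `LOCAL_RANK`, `MLFLOW_URI`, and `AWS_DEFAULT_REGION` are the ones the module reads; how `InstanceMetadata` is obtained, how `UniqueID` is populated (including `WORLD_SIZE`, `NODE_RANK`, `MASTER_ADDR`), and the `run_preflight_logging` wrapper itself are assumptions for illustration, not the package's actual entry point.

# Hypothetical wiring of the two-layer MLflow setup.
import os

from fkat.utils.aws.imds import InstanceMetadata
from fkat.utils.cuda.preflight.health_check.helpers import InstanceStats, UniqueID, fetch_gpu_info
from fkat.utils.cuda.preflight.health_check.logger import end_all_mlflow_active_runs, initialize_mlflow


def run_preflight_logging(instance_metadata: InstanceMetadata) -> str:
    # Collect GPU identity information and build the rank/GPU descriptor (field sources are assumed).
    gpu_info, gpu_hash_id = fetch_gpu_info()
    unique_id = UniqueID(
        rank=int(os.environ["RANK"]),
        world_size=int(os.environ["WORLD_SIZE"]),
        local_rank=int(os.environ["LOCAL_RANK"]),
        num_nodes=int(os.environ.get("NUM_NODES", "1")),
        node_rank=int(os.environ.get("NODE_RANK", "0")),
        gpu_per_node=len(gpu_info),
        gpu_hash_id=gpu_hash_id,
        master_addr=os.environ.get("MASTER_ADDR", "localhost"),
    )
    instance_stats = InstanceStats(instance_metadata, gpu_info)

    # Creates/joins the job-level and instance-level runs and returns the parent run id.
    job_run_id = initialize_mlflow(unique_id, instance_stats)
    try:
        ...  # run health checks here, logging against the active nested run
    finally:
        end_all_mlflow_active_runs()
    return job_run_id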
fkat/utils/cuda/preflight/health_check/timer.py
@@ -0,0 +1,31 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import time
+ from contextlib import contextmanager
+ from collections.abc import Iterator, Mapping
+
+
+ class Timer(Mapping[str, float]):
+     def __init__(self) -> None:
+         self._times: dict[str, list[float]] = {}
+
+     @contextmanager
+     def __call__(self, name: str) -> Iterator[None]:
+         start = time.perf_counter()
+         try:
+             yield
+         finally:
+             end = time.perf_counter()
+             self._times.setdefault(name, []).append(1000 * (end - start))
+
+     def __getitem__(self, name: str) -> float:
+         if len(self._times[name]) == 1:
+             return self._times[name][0]
+         else:
+             return max(self._times[name])
+
+     def __iter__(self) -> Iterator[str]:
+         return iter(self._times)
+
+     def __len__(self) -> int:
+         return len(self._times)
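`Timer` behaves as a read-only Mapping whose values are wall-clock durations in milliseconds; repeated timings under the same name are all kept, and a lookup returns either the single measurement or the maximum across repeats. A minimal usage sketch, assuming only the import path shown in the file list (the `time.sleep` calls stand in for real work):

# Minimal usage sketch of the Timer context-manager mapping.
import time

from fkat.utils.cuda.preflight.health_check.timer import Timer

timer = Timer()

with timer("allreduce_warmup"):
    time.sleep(0.01)  # stand-in for the timed work

with timer("allreduce_warmup"):
    time.sleep(0.02)

# Mapping access: one entry per name; the value is the max across repeats, in milliseconds.
for name, max_ms in timer.items():
    print(f"{name}: {max_ms:.1f} ms")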