fkat 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/utils/cuda/preflight/run.py ADDED
@@ -0,0 +1,560 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import os
4
+ import time
5
+ import socket
6
+ import sys
7
+ import subprocess
8
+ import torch
9
+ import mlflow
10
+ import torch.multiprocessing as mp
11
+ import logging
12
+ from typing import Any
13
+ from fkat.utils.cuda.preflight.health_check.gpu_stress_test import run_gpu_stress_test
14
+ from fkat.utils.cuda.preflight.health_check.gpu_connection_test import run_gpu_connection_test
15
+ from fkat.utils.aws.imds import instance_metadata
16
+ from fkat.utils.cuda.preflight.health_check.helpers import (
17
+ fetch_gpu_info,
18
+ generate_test_folder_name,
19
+ InstanceStats,
20
+ strip_aws_batch_id,
21
+ checkfunction_timeout_manager,
22
+ UniqueID,
23
+ destroy_process_group_if_initialized,
24
+ )
25
+ from fkat.utils.cuda.preflight.health_check.logger import (
26
+ end_all_mlflow_active_runs,
27
+ initialize_mlflow,
28
+ )
29
+ from fkat.utils.cuda.preflight.health_check.aws_instance_config import INSTANCE_BENCHMARK_CONFIGS
30
+ from fkat.utils.cuda.preflight.health_check.ddb_client import HealthStatusDDBClient
31
+ from fkat.utils.cuda.preflight.health_check.constants import (
32
+ STRESS_TEST_MAX_RUNTIME_IN_SEC,
33
+ NUM_LOOPS_RANGE,
34
+ SINGLE_NODE_LATENCY_THRESHOLD_FACTOR,
35
+ PAIR_NODES_LATENCY_THRESHOLD_FACTOR,
36
+ AWS_BATCH_JOB_ID,
37
+ PREFLIGHT_MLFLOW_METRIC_PREFIX,
38
+ HEALTH_CHECK_TIMEOUT_SECS,
39
+ )
40
+
41
+ PG_TIMEOUT_MIN = float(os.environ.get("PG_TIMEOUT_MIN", 2))
42
+ MLFLOW_CHECK_INTERVAL_SECS = 5  # check MLflow every 5 seconds
43
+
44
+ test_folder_name = generate_test_folder_name()
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+
49
+ def _is_result_within_threshold(result: float, baseline: float, factor: float) -> bool:
50
+ """
51
+ Determines if a test result is within an acceptable threshold.
52
+
53
+ Args:
54
+ result (float): The measured test result.
55
+ baseline (float): The baseline or expected result value.
56
+ factor (float): The multiplicative threshold factor.
57
+
58
+ Returns:
59
+ bool: True if the result is within (baseline * factor), False otherwise.
60
+ """
61
+ threshold = baseline * factor
62
+ return result <= threshold
63
+
64
+
65
+ def _run_single_gpu_stress_test(
66
+ unique_id: UniqueID,
67
+ instance_stats: InstanceStats,
68
+ ) -> bool:
69
+ """
70
+ Executes a single GPU stress test to evaluate basic GPU performance.
71
+
72
+ If the current process is local_rank == 0, it runs the stress test and logs results to MLflow.
73
+ Other ranks wait for MLflow to log the result.
74
+
75
+ Args:
76
+ unique_id (UniqueID): Metadata of the current process containing rank and cluster topology.
77
+ instance_stats (InstanceStats): System-specific configuration and GPU info.
78
+
79
+ Returns:
80
+ bool: True if the test passes all conditions, False otherwise.
81
+ """
82
+
83
+ instance_type = instance_stats.instance_type
84
+
85
+ result = {}
86
+ is_passed = True
87
+
88
+ if unique_id.local_rank == 0:
89
+ logger.info("\n\n************** Start run_gpu_stress_test() ***********")
90
+ hostname = socket.gethostname()
91
+
92
+ try:
93
+ result = checkfunction_timeout_manager(
94
+ run_gpu_stress_test,
95
+ kwargs={
96
+ "gpu_mem": INSTANCE_BENCHMARK_CONFIGS[instance_type].gpu_memory_gb,
97
+ "max_runtime": STRESS_TEST_MAX_RUNTIME_IN_SEC,
98
+ },
99
+ )
100
+ baseline_num_loops_low_limit = (
101
+ INSTANCE_BENCHMARK_CONFIGS[instance_type].baseline_num_loops - NUM_LOOPS_RANGE
102
+ )
103
+ baseline_num_loops_up_limit = INSTANCE_BENCHMARK_CONFIGS[instance_type].baseline_num_loops + NUM_LOOPS_RANGE
104
+ is_within_loops_range = (
105
+ baseline_num_loops_low_limit <= int(result["num_loop"]) <= baseline_num_loops_up_limit
106
+ )
107
+
108
+ is_passed = is_within_loops_range and result.get("all_gpus", False)
109
+ logger.info(f"{hostname}: {result}")
110
+ except Exception as e:
111
+ is_passed = False
112
+ logger.error(f"Single GPU stress test failed with error: {str(e)}")
113
+ result = {"error": str(e)}
114
+ finally:
115
+ torch.cuda.empty_cache()
116
+
117
+ mlflow.log_metric(f"{PREFLIGHT_MLFLOW_METRIC_PREFIX}/gpu_stress_test/is_test_pass", is_passed) # type: ignore[possibly-unbound-attribute]
118
+
119
+ if "error" not in result:
120
+ logger.info("Update instance stats to include test result")
121
+ instance_stats.gpu_info = {
122
+ str(instance_stats.gpu_info[key]["uuid"]): result[str(key)]
123
+ for key in instance_stats.gpu_info # type: ignore[index]
124
+ }
125
+
126
+ logger.info("\n\n************** Finish run_gpu_stress_test() ***********")
127
+ else:
128
+ max_retries = HEALTH_CHECK_TIMEOUT_SECS // MLFLOW_CHECK_INTERVAL_SECS + 1
129
+ attempt = 0
130
+ metric_name = f"{PREFLIGHT_MLFLOW_METRIC_PREFIX}/gpu_stress_test/is_test_pass"
131
+ while attempt < max_retries:
132
+ active_run = mlflow.active_run() # type: ignore[possibly-unbound-attribute]
133
+ if not active_run:
134
+ raise Exception("Can't find mlflow run.")
135
+
136
+ run_data = mlflow.get_run(active_run.info.run_id).data # type: ignore[possibly-unbound-attribute]
137
+ if metric_name not in run_data.metrics:
138
+ attempt += 1
139
+ time.sleep(MLFLOW_CHECK_INTERVAL_SECS)
140
+ else:
141
+ break # metric was found, exit loop
142
+
143
+ logger.info("preflight on local_rank check ends")
144
+
145
+ return is_passed
146
+
147
+
148
+ def _run_single_node_nvlink_test(
149
+ unique_id: UniqueID,
150
+ instance_stats: InstanceStats,
151
+ ) -> bool:
152
+ """
153
+ Runs a single-node NVLink bandwidth test to validate GPU interconnects within a node.
154
+
155
+ Args:
156
+ unique_id (UniqueID): Metadata of the current process containing rank and cluster topology.
157
+ instance_stats (InstanceStats): Configuration and GPU data for this node.
158
+
159
+ Returns:
160
+ bool: True if the NVLink test latency is within acceptable limits, False otherwise.
161
+ """
162
+ instance_type = instance_stats.instance_type
163
+ logger.info(f"\n\n************** Start nvlink_test() for unique_id {unique_id} ***********")
164
+
165
+ results = {}
166
+
167
+ try:
168
+ print(
169
+ f"start single node test, with master_addr {str(unique_id.master_addr)}, "
170
+ f"world_size {unique_id.gpu_per_node}, "
171
+ f"rank {unique_id.local_rank}, local_rank {os.environ['LOCAL_RANK']}"
172
+ )
173
+ results = checkfunction_timeout_manager(
174
+ func=run_gpu_connection_test,
175
+ kwargs={
176
+ "dim_items": int(INSTANCE_BENCHMARK_CONFIGS[instance_type].config_dim ** 0.5),
177
+ "loops": int(INSTANCE_BENCHMARK_CONFIGS[instance_type].num_loops_single),
178
+ "rail": unique_id.local_rank,
179
+ "mode": "single",
180
+ "master_addr": str(unique_id.master_addr),
181
+ "master_port": os.environ["MASTER_PORT"],
182
+ "world_size": unique_id.gpu_per_node,
183
+ "rank": unique_id.local_rank,
184
+ },
185
+ )
186
+ is_passed = _is_result_within_threshold(
187
+ result=results[-1],
188
+ baseline=INSTANCE_BENCHMARK_CONFIGS[instance_type].baseline_single_node_latency,
189
+ factor=SINGLE_NODE_LATENCY_THRESHOLD_FACTOR,
190
+ )
191
+
192
+ except Exception as e:
193
+ logger.error(f"Single node nvlink test failed with error: {str(e)}")
194
+ is_passed = False
195
+ results = {"error": str(e)}
196
+ finally:
197
+ destroy_process_group_if_initialized()
198
+
199
+ mlflow.log_metric( # type: ignore[possibly-unbound-attribute]
200
+ f"{PREFLIGHT_MLFLOW_METRIC_PREFIX}/single_node_test/is_test_pass_gpu_{unique_id.local_rank}",
201
+ is_passed,
202
+ )
203
+
204
+ logger.info(
205
+ f"\n\n************** Finish nvlink_test() for [unique_id, node_rank] is [{unique_id.__dict__}], "
206
+ f"results is [{results}] ***********"
207
+ )
208
+
209
+ return is_passed
210
+
211
+
212
+ def _run_multi_nodes_nvlink_test(
213
+ job_level_mlflow_run_id: str,
214
+ unique_id: UniqueID,
215
+ test_nodes_size: int,
216
+ instance_stats: InstanceStats,
217
+ ) -> bool:
218
+ """
219
+ Performs a multi-node NVLink test across two or more nodes to verify inter-node GPU connectivity.
220
+
221
+ Args:
222
+ job_level_mlflow_run_id (str): MLflow run ID for job-level coordination.
223
+ unique_id (UniqueID): Metadata of the current process containing rank and cluster topology.
224
+ test_nodes_size (int): Number of nodes involved in each test group (typically 2).
225
+ instance_stats (InstanceStats): GPU and instance-level metadata.
226
+
227
+ Returns:
228
+ bool: True if latency is within expected bounds, False otherwise.
229
+ """
230
+
231
+ instance_type = instance_stats.instance_type
232
+ node_pair_rank = unique_id.rank // test_nodes_size
233
+ node_pair_loc_rank = unique_id.rank % (test_nodes_size * unique_id.gpu_per_node)
234
+
235
+ job_level_mlflow_run_params = mlflow.get_run(job_level_mlflow_run_id).data.params # type: ignore[possibly-unbound-attribute]
236
+
237
+ logger.info(f"node_rank is {unique_id.node_rank} and node_pair_id is {node_pair_rank}")
238
+ logger.info(
239
+ f"\n\n************** Start {test_nodes_size}-node nvlink_test() for node {unique_id.node_rank} ***********"
240
+ )
241
+
242
+ # The process group's "MASTER_ADDR" should be the address of "node 0" of this node pair; every 2 nodes join the same process group.
243
+ new_master_addr = job_level_mlflow_run_params.get(
244
+ f"instance_addr_node_{unique_id.node_rank - unique_id.node_rank % 2}"
245
+ )
246
+
247
+ results = {}
248
+
249
+ try:
250
+ results = checkfunction_timeout_manager(
251
+ func=run_gpu_connection_test,
252
+ kwargs={
253
+ "dim_items": int(INSTANCE_BENCHMARK_CONFIGS[instance_type].config_dim ** 0.5),
254
+ "loops": int(INSTANCE_BENCHMARK_CONFIGS[instance_type].num_loops_multi),
255
+ "rail": int(os.environ["LOCAL_RANK"]),
256
+ "mode": "multi",
257
+ "master_addr": new_master_addr,
258
+ "master_port": os.environ["MASTER_PORT"],
259
+ "world_size": test_nodes_size * unique_id.gpu_per_node,
260
+ "rank": node_pair_loc_rank,
261
+ },
262
+ )
263
+ is_passed = _is_result_within_threshold(
264
+ result=results[-1],
265
+ baseline=INSTANCE_BENCHMARK_CONFIGS[instance_type].baseline_pair_nodes_latency,
266
+ factor=PAIR_NODES_LATENCY_THRESHOLD_FACTOR,
267
+ )
268
+
269
+ except Exception as e:
270
+ logger.error(f"Multi node nvlink test failed with error: {str(e)}")
271
+ is_passed = False
272
+ results = {"error": str(e)}
273
+ finally:
274
+ destroy_process_group_if_initialized()
275
+
276
+ mlflow.log_metric( # type: ignore[possibly-unbound-attribute]
277
+ f"{PREFLIGHT_MLFLOW_METRIC_PREFIX}/multi_node_test/is_test_pass_gpu_{os.environ['LOCAL_RANK']}",
278
+ is_passed,
279
+ )
280
+
281
+ logger.info(
282
+ f"\n\n************** Finish {test_nodes_size}-node nvlink_test() for [unique_id, node_rank, node_pair_id] is"
283
+ f"[{os.environ['LOCAL_RANK']}, {unique_id.node_rank}, {node_pair_rank}], result is [{results}] ***********"
284
+ )
285
+
286
+ return is_passed
287
+
288
+
289
+ def _get_upload_instance_info() -> tuple[InstanceStats, str]:
290
+ """
291
+ Retrieves instance metadata and GPU hash ID for the current node.
292
+
293
+ Returns:
294
+ tuple[InstanceStats, str]: A tuple containing instance statistics and GPU hash ID.
295
+ """
296
+ gpu_info, instance_gpu_hash_id = fetch_gpu_info()
297
+ instance_info = instance_metadata()
298
+ instance_stats = InstanceStats(instance_info, gpu_info)
299
+
300
+ return instance_stats, instance_gpu_hash_id
301
+
302
+
303
+ def fetch_node_info() -> tuple[bool | str, UniqueID, InstanceStats, str]:
304
+ """
305
+ Gathers necessary metadata for preflight health checking.
306
+
307
+ Returns:
308
+ tuple: A tuple containing:
309
+ - fetch success status (bool or error message),
310
+ - UniqueID object,
311
+ - InstanceStats object,
312
+ - Job-level MLflow run ID.
313
+ """
314
+ try:
315
+ instance_stats, instance_gpu_hash_id = _get_upload_instance_info()
316
+
317
+ unique_id = UniqueID(
318
+ rank=int(os.environ["RANK"]),
319
+ world_size=int(os.environ["WORLD_SIZE"]),
320
+ local_rank=int(os.environ["LOCAL_RANK"]),
321
+ master_addr=socket.gethostbyname(socket.gethostname()),
322
+ num_nodes=int(os.environ["GROUP_WORLD_SIZE"]),
323
+ gpu_per_node=int(os.environ["LOCAL_WORLD_SIZE"]),
324
+ gpu_hash_id=instance_gpu_hash_id,
325
+ node_rank=int(os.environ["RANK"]) // int(os.environ["LOCAL_WORLD_SIZE"]),
326
+ )
327
+
328
+ job_level_mlflow_run_id = initialize_mlflow(unique_id, instance_stats=instance_stats)
329
+
330
+ is_passed = True
331
+
332
+ logger.info("Successfully gathered all information about the instance.")
333
+ except Exception as e:
334
+ is_passed = False
335
+ logger.error(f"Failed to get all information about the instance. Error: {str(e)}")
336
+
337
+ return is_passed, unique_id, instance_stats, job_level_mlflow_run_id
338
+
339
+
340
+ def _log_result_to_mlflow(instance_stats: InstanceStats) -> None:
341
+ """
342
+ Logs the final instance health status and artifacts to MLflow.
343
+
344
+ Args:
345
+ instance_stats (InstanceStats): Instance data including the computed health status.
346
+ """
347
+ mlflow.log_metric( # type: ignore[possibly-unbound-attribute]
348
+ f"{PREFLIGHT_MLFLOW_METRIC_PREFIX}/isinstance_healthy",
349
+ instance_stats.healthy,
350
+ )
351
+
352
+ # End the instance-level run so the job-level mlflow run becomes active again
353
+ mlflow.end_run() # type: ignore[possibly-unbound-attribute]
354
+ mlflow.log_artifacts(test_folder_name) # type: ignore[possibly-unbound-attribute]
355
+
356
+
357
+ def _log_result_to_ddb(all_check_result: dict[str, Any], unique_id: UniqueID, instance_stats: InstanceStats) -> None:
358
+ """
359
+ Writes the instance's health status and test results to DynamoDB.
360
+
361
+ Args:
362
+ all_check_result (dict): Dictionary containing the results of each health check.
363
+ unique_id (UniqueID): Metadata about the instance within the cluster.
364
+ instance_stats (InstanceStats): GPU and instance configuration details.
365
+ """
366
+ try:
367
+ ddb_client = HealthStatusDDBClient(region=instance_stats.instance_region)
368
+ ddb_item = ddb_client.generate_ddb_item(
369
+ instance_gpu_hash_id=unique_id.gpu_hash_id,
370
+ instance_health=instance_stats.healthy,
371
+ gpu_stats=instance_stats.gpu_info,
372
+ batch_job_id=strip_aws_batch_id(os.getenv(AWS_BATCH_JOB_ID, default="local")),
373
+ instance_type=instance_stats.instance_type,
374
+ instance_id=instance_stats.instance_id,
375
+ test_result=all_check_result,
376
+ )
377
+ logger.info(f"Writing {ddb_item} to ddb")
378
+ ddb_client.put_item(ddb_item)
379
+ except Exception as e:
380
+ logger.error(f"Failed to write instance health report to DynamoDB. Error: {str(e)}")
381
+
382
+
383
+ def log_preflight_results(all_check_result: dict[str, Any], unique_id: UniqueID, instance_stats: InstanceStats) -> None:
384
+ """
385
+ Logs the result of the health check to both MLflow and DynamoDB.
386
+
387
+ The health report is only written when local_rank == 0.
388
+
389
+ Args:
390
+ all_check_result (dict): Health check results keyed by test name.
391
+ unique_id (UniqueID): Cluster context and rank information.
392
+ instance_stats (InstanceStats): Node-level configuration and test results.
393
+ """
394
+ # Only write instance health report when local rank is 0
395
+ if unique_id.local_rank == 0:
396
+ # Enter the instance_level mlflow_run
397
+ try:
398
+ _log_result_to_mlflow(instance_stats)
399
+
400
+ _log_result_to_ddb(
401
+ all_check_result=all_check_result,
402
+ unique_id=unique_id,
403
+ instance_stats=instance_stats,
404
+ )
405
+ except Exception as e:
406
+ logger.error(f"Failed to write instance health report to mlflow. Error: {str(e)}")
407
+
408
+
409
+ def preflight_health_check() -> None:
410
+ """
411
+ Performs a preflight diagnostic to validate whether the current instance is suitable for distributed training.
412
+
413
+ Steps performed:
414
+ 1. Gathers instance metadata, GPU hash ID, and cluster information.
415
+ 2. Runs a GPU stress test to verify core GPU functionality.
416
+ 3. Executes a single-node NVLink test to validate intra-node GPU connectivity.
417
+ 4. Conditionally runs a multi-node NVLink test for inter-node GPU connectivity (if node count is even and >1).
418
+ 5. Aggregates all test results and determines the node's overall health.
419
+ 6. Logs the test results and health status to MLflow and DynamoDB.
420
+ 7. Cleans up any distributed process groups and MLflow state.
421
+
422
+ Side Effects:
423
+ - Updates the instance health status in MLflow and DynamoDB.
424
+ - Logs diagnostic outputs and results.
425
+ - Delays execution based on rank and test coordination logic.
426
+
427
+ Note:
428
+ This function must be called within a properly initialized distributed environment with expected env vars:
429
+ `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `LOCAL_WORLD_SIZE`, `GROUP_WORLD_SIZE`.
430
+
431
+ Raises:
432
+ None directly, but will log and mark the instance as unhealthy if any test fails.
433
+ """
434
+ is_all_check_success = True
435
+ all_check_result: dict[str, str | int] = {}
436
+
437
+ # fetch instance info
438
+ (
439
+ all_check_result["fetch_instance_info"],
440
+ unique_id,
441
+ instance_stats,
442
+ job_level_mlflow_run_id,
443
+ ) = fetch_node_info()
444
+ logger.info(f"fetch_instance_info result is {all_check_result['fetch_instance_info']}")
445
+
446
+ # Run gpu stress test
447
+ all_check_result["single_gpu_stress_test"] = all_check_result[
448
+ "fetch_instance_info"
449
+ ] and _run_single_gpu_stress_test(unique_id=unique_id, instance_stats=instance_stats)
450
+ logger.info(
451
+ f"_run_single_gpu_stress_test result is {all_check_result['single_gpu_stress_test']}, "
452
+ f"in local_rank {unique_id.local_rank}"
453
+ )
454
+
455
+ # Wait until the port is available
456
+ time.sleep(10)
457
+
458
+ # Run single-node test
459
+ all_check_result["single_node_nvlink_test"] = all_check_result[
460
+ "fetch_instance_info"
461
+ ] and _run_single_node_nvlink_test(unique_id=unique_id, instance_stats=instance_stats)
462
+ logger.info(
463
+ f"_run_single_node_nvlink_test result is {all_check_result['single_node_nvlink_test']}, "
464
+ f"in local_rank {unique_id.local_rank}"
465
+ )
466
+
467
+ # Wait until the port is available
468
+ time.sleep(10)
469
+
470
+ # Run multi-node test
471
+ if unique_id.num_nodes == 1:
472
+ logger.info("Number of nodes is 1; skipping multi_nodes_nvlink_test")
473
+ all_check_result["multi_node_nvlink_test"] = "Not Checked"
474
+ elif unique_id.num_nodes % 2 != 0:
475
+ logger.info("Number of nodes is odd; skipping multi_nodes_nvlink_test for now")
476
+ all_check_result["multi_node_nvlink_test"] = "Not Checked"
477
+ else:
478
+ all_check_result["multi_node_nvlink_test"] = all_check_result[
479
+ "fetch_instance_info"
480
+ ] and _run_multi_nodes_nvlink_test(
481
+ job_level_mlflow_run_id=job_level_mlflow_run_id,
482
+ unique_id=unique_id,
483
+ test_nodes_size=2,
484
+ instance_stats=instance_stats,
485
+ )
486
+ logger.info(
487
+ f"_run_multi_nodes_nvlink_test result is {all_check_result['multi_node_nvlink_test']}, "
488
+ f"in local_rank {unique_id.local_rank}"
489
+ )
490
+
491
+ # Each test above catches its own errors, so we can aggregate the results and record overall health.
492
+ instance_stats.healthy = all(all_check_result.values())
493
+ logger.info(
494
+ f"After all the tests local rank : {unique_id.local_rank} has health status of {instance_stats.healthy}"
495
+ )
496
+
497
+ log_preflight_results(
498
+ all_check_result=all_check_result,
499
+ unique_id=unique_id,
500
+ instance_stats=instance_stats,
501
+ )
502
+
503
+ if not instance_stats.healthy:
504
+ logger.error(f"This instance is not healthy, with test result {all_check_result}")
505
+
506
+ end_all_mlflow_active_runs()
507
+ destroy_process_group_if_initialized()
508
+
509
+ logger.info("Preflight check ends successfully.")
510
+
511
+
512
+ def isolate_bad_node() -> None:
513
+ """
514
+ Checks the health status of the current instance from DynamoDB and isolates it if unhealthy.
515
+
516
+ This function:
517
+ 1. Retrieves GPU hash ID and instance metadata.
518
+ 2. Queries the health status record from DynamoDB using the GPU hash ID.
519
+ 3. If the instance was never scanned, raises an error to indicate unexpected behavior.
520
+ 4. If the instance is unhealthy, the process enters an infinite sleep to prevent further participation.
521
+ 5. If the instance is healthy, it sleeps for 15 minutes to allow other nodes to complete isolation logic.
522
+
523
+ This function is typically used in orchestration flows to quarantine failed nodes.
524
+
525
+ Raises:
526
+ RuntimeError: If the instance health record is missing in the database.
527
+ """
528
+ _, instance_gpu_hash_id = fetch_gpu_info()
529
+ instance_meta = instance_metadata()
530
+
531
+ ddb_client = HealthStatusDDBClient(instance_meta.region)
532
+
533
+ instance_stats = ddb_client.get_item(partition_key=instance_gpu_hash_id)
534
+
535
+ if not instance_stats:
536
+ raise RuntimeError(
537
+ f"Instance with instance_id {instance_meta.instance_id}, gpu_hash_id {instance_gpu_hash_id}, "
538
+ "has not been scanned."
539
+ )
540
+ elif instance_stats and not instance_stats["healthy"]:
541
+ logger.info("This node is unhealthy; sleeping forever to keep it isolated.")
542
+ while True:
543
+ time.sleep(1)
544
+ else:
545
+ logger.info("This node is healthy; sleeping for 15 minutes while other isolation jobs finish.")
546
+ time.sleep(15 * 60)
547
+
548
+
549
+ def check() -> None:
550
+ """
551
+ Executes the current script using the system Python interpreter.
552
+
553
+ Intended as a CLI entry point for basic validation or debugging.
554
+ """
555
+ subprocess.run([sys.executable, __file__])
556
+
557
+
558
+ if __name__ == "__main__":
559
+ mp.set_start_method("spawn", force=True)
560
+ preflight_health_check()
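
The pass/fail logic above combines a loop-count window for the stress test with a multiplicative latency threshold for the NVLink tests. A minimal standalone sketch of those two criteria, using hypothetical baseline numbers in place of the real `INSTANCE_BENCHMARK_CONFIGS` entries and constants:

    # Hypothetical baselines for illustration only; the real values come from
    # INSTANCE_BENCHMARK_CONFIGS and the constants module.
    BASELINE_NUM_LOOPS = 1000
    NUM_LOOPS_RANGE = 50
    BASELINE_SINGLE_NODE_LATENCY = 2.5          # seconds
    SINGLE_NODE_LATENCY_THRESHOLD_FACTOR = 1.2

    def stress_test_passes(num_loops: int, all_gpus_ok: bool) -> bool:
        # The measured loop count must fall inside the baseline window and
        # every GPU must have completed the workload.
        low = BASELINE_NUM_LOOPS - NUM_LOOPS_RANGE
        high = BASELINE_NUM_LOOPS + NUM_LOOPS_RANGE
        return low <= num_loops <= high and all_gpus_ok

    def nvlink_test_passes(latency: float) -> bool:
        # Same check as _is_result_within_threshold: pass if the measured
        # latency does not exceed baseline * factor.
        return latency <= BASELINE_SINGLE_NODE_LATENCY * SINGLE_NODE_LATENCY_THRESHOLD_FACTOR

    print(stress_test_passes(980, True))   # True: inside the loop window, all GPUs finished
    print(nvlink_test_passes(3.2))         # False: 3.2 > 2.5 * 1.2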
fkat/utils/cuda/xid.py ADDED
@@ -0,0 +1,48 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from __future__ import annotations
4
+ import logging
5
+ import re
6
+ import subprocess
7
+ import sys
8
+ import multiprocessing
9
+
10
+ log = logging.getLogger(__name__)
11
+ log.setLevel(logging.INFO)
12
+ log.addHandler(logging.StreamHandler(sys.stdout))
13
+
14
+ XID_PAT = re.compile(r"\[(.*)\] NVRM: Xid \(.*\): (\d+),")
15
+
16
+
17
+ def detect_xid_errors(
18
+ xid_check: multiprocessing.synchronize.Event, # type: ignore[unresolved-attribute]
19
+ xid_errors: multiprocessing.Queue[set[int]],
20
+ ) -> None:
21
+ """
22
+ Detect XID errors by monitoring system logs.
23
+
24
+ Args:
25
+ xid_check: Event to trigger checking for XID errors
26
+ xid_errors: Queue to put detected XID errors
27
+ """
28
+ try:
29
+ log.info("\nChecking for Xid errors in a background process ...")
30
+ while True:
31
+ xid_check.wait()
32
+ xid_check.clear()
33
+ xids: set[int] = set()
34
+ f = subprocess.check_output("dmesg -Tc", shell=True)
35
+ lines = f.decode("utf8", errors="ignore").split("\n")
36
+ for line in lines:
37
+ res = XID_PAT.match(line)
38
+ if res:
39
+ xid = int(res.group(2))
40
+ xids.add(xid)
41
+ xid_errors.put(xids)
42
+ except Exception as e:
43
+ if hasattr(e, "returncode") and e.returncode == 1:
44
+ log.info(
45
+ "Xid monitoring requires running in privileged mode; "
46
+ "ensure privileged access is available to read dmesg"
47
+ )
48
+ log.info(f"error executing command: {e}")
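
A minimal sketch (not part of the package) of how `detect_xid_errors` could be driven from a parent process; it assumes the process has privileged access to `dmesg`:

    import multiprocessing as mp

    from fkat.utils.cuda.xid import detect_xid_errors

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        xid_check = ctx.Event()
        xid_errors = ctx.Queue()
        worker = ctx.Process(target=detect_xid_errors, args=(xid_check, xid_errors), daemon=True)
        worker.start()

        xid_check.set()                    # ask the worker to scan dmesg once
        xids = xid_errors.get(timeout=60)  # set of Xid codes seen since the last scan
        print(f"Detected Xid errors: {xids or 'none'}")
        worker.terminate()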
fkat/utils/logging.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """
4
+ Utilities to standardize and simplify usages of Python's built-in logging module.
5
+ """
6
+
7
+ import logging
8
+
9
+ from lightning.pytorch.utilities import rank_zero_only
10
+
11
+
12
+ def rank0_logger(name: str = __name__) -> logging.Logger:
13
+ """Initializes multi-GPU-friendly python command line logger."""
14
+ logger = logging.getLogger(name)
15
+
16
+ # this ensures all logging levels get marked with the rank zero decorator
17
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
18
+ for level in (
19
+ "debug",
20
+ "info",
21
+ "warning",
22
+ "error",
23
+ "exception",
24
+ "fatal",
25
+ "critical",
26
+ ):
27
+ setattr(logger, level, rank_zero_only(getattr(logger, level)))
28
+ return logger
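
A short usage sketch (assumed, not from the package): in a multi-GPU run only global rank 0 emits these messages, because every logging method is wrapped with `rank_zero_only`:

    from fkat.utils.logging import rank0_logger

    log = rank0_logger(__name__)
    log.info("Starting training")   # printed once, on global rank 0 only
    log.warning("Low disk space")   # likewise suppressed on all other ranks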
fkat/utils/mlflow.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import logging
4
+ from typing import TYPE_CHECKING
5
+
6
+ from lightning.pytorch.loggers import MLFlowLogger
7
+
8
+ if TYPE_CHECKING:
9
+ from lightning import Trainer
10
+
11
+ log = logging.getLogger(__name__)
12
+
13
+
14
+ def mlflow_logger(trainer: "Trainer") -> "MLFlowLogger | None":
15
+ """
16
+ Returns MLFlowLogger from trainer as constructed by PyTorch Lightning.
17
+ """
18
+ from fkat.pytorch.loggers import _is_logger_type
19
+
20
+ for logger in trainer.loggers:
21
+ if _is_logger_type(logger, "MLFlowLogger"):
22
+ return logger # type: ignore[return-value]
23
+ return None
24
+
25
+
26
+ def broadcast_mlflow_run_id(mlflow: "MLFlowLogger", trainer: "Trainer") -> None:
27
+ """
28
+ Broadcast mlflow run_id from rank0 to all ranks and setup the mlflow logger.
29
+ We assume PTL mlflow logger is only initialized with a run_id on rank0 via
30
+ logger.experiment.
31
+ """
32
+ mlflow._run_id = trainer.strategy.broadcast(mlflow.run_id, src=0)
33
+ log.debug(f"Received mlflow run_id: {mlflow.run_id}")
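
A hypothetical callback sketch showing one way these helpers could be combined: fetch the Lightning-managed `MLFlowLogger` and share its rank-0 `run_id` with every rank during setup. The callback class and hook wiring are assumptions, not part of the package:

    from lightning.pytorch import Callback, LightningModule, Trainer

    from fkat.utils.mlflow import broadcast_mlflow_run_id, mlflow_logger

    class ShareMLflowRunId(Callback):
        def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
            logger = mlflow_logger(trainer)
            if logger is None:
                return
            if trainer.global_rank == 0:
                _ = logger.experiment  # force run creation on rank 0 so a run_id exists
            broadcast_mlflow_run_id(logger, trainer)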
fkat/utils/pandas.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import numpy as np
4
+ import pandas as pd
5
+ from typing import Any
6
+ from collections.abc import Iterator
7
+
8
+
9
+ def iter_rows(table: Iterator[pd.DataFrame], replace_nan: bool = True) -> Iterator[dict[str, Any]]:
10
+ """
11
+ Generator function to iterate over the rows of a stream of Pandas :class:`DataFrame`\\s, chunk by chunk.
12
+
13
+ Args:
14
+ table (Iterator[pd.DataFrame]): Pandas :class:`DataFrame`\\s.
15
+ replace_nan (bool): Whether to replace NaN with None.
16
+ Defaults to `True`.
17
+ Yields:
18
+ dict[str, Any]: Dictionary representing each row.
19
+ """
20
+ for chunk in table:
21
+ if replace_nan:
22
+ chunk = chunk.replace({np.nan: None})
23
+ columns = chunk.to_dict()
24
+ for idx in sorted(columns[list(columns.keys())[0]].keys()):
25
+ yield {str(col): columns[col][idx] for col in columns}
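
A usage sketch (assumed): feed `iter_rows` any iterator of `DataFrame` chunks, for example the object returned by `pandas.read_csv(..., chunksize=N)`. With the tiny in-memory chunk below, NaN values come back as `None` because `replace_nan` defaults to `True`:

    import pandas as pd

    from fkat.utils.pandas import iter_rows

    # One small chunk standing in for e.g. pd.read_csv("data.csv", chunksize=10_000).
    chunks = iter([pd.DataFrame({"a": [1, 2], "b": [3.0, float("nan")]})])
    for row in iter_rows(chunks):
        print(row)
    # Expected in this sketch:
    # {'a': 1, 'b': 3.0}
    # {'a': 2, 'b': None}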