fkat-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fkat/__init__.py +147 -0
  2. fkat/data/__init__.py +15 -0
  3. fkat/data/data_module.py +198 -0
  4. fkat/data/datasets/__init__.py +19 -0
  5. fkat/data/datasets/dict.py +78 -0
  6. fkat/data/datasets/json.py +176 -0
  7. fkat/data/datasets/map.py +90 -0
  8. fkat/data/datasets/parquet.py +242 -0
  9. fkat/data/datasets/sized.py +31 -0
  10. fkat/data/dict.py +42 -0
  11. fkat/data/samplers/__init__.py +9 -0
  12. fkat/data/samplers/dict.py +38 -0
  13. fkat/data/samplers/sized.py +16 -0
  14. fkat/data/samplers/strategies.py +68 -0
  15. fkat/data/sharded.py +718 -0
  16. fkat/data/shm.py +364 -0
  17. fkat/predict.py +32 -0
  18. fkat/py.typed +0 -0
  19. fkat/pytorch/__init__.py +3 -0
  20. fkat/pytorch/actions/__init__.py +11 -0
  21. fkat/pytorch/actions/aws/__init__.py +3 -0
  22. fkat/pytorch/actions/aws/batch.py +29 -0
  23. fkat/pytorch/actions/aws/ec2.py +61 -0
  24. fkat/pytorch/callbacks/__init__.py +2 -0
  25. fkat/pytorch/callbacks/cuda/__init__.py +16 -0
  26. fkat/pytorch/callbacks/cuda/cache.py +115 -0
  27. fkat/pytorch/callbacks/cuda/memory.py +200 -0
  28. fkat/pytorch/callbacks/cuda/nsys.py +199 -0
  29. fkat/pytorch/callbacks/cuda/nvtx.py +288 -0
  30. fkat/pytorch/callbacks/cuda/xid.py +173 -0
  31. fkat/pytorch/callbacks/debugging/__init__.py +9 -0
  32. fkat/pytorch/callbacks/debugging/introspection.py +569 -0
  33. fkat/pytorch/callbacks/debugging/optimizer.py +45 -0
  34. fkat/pytorch/callbacks/gc.py +146 -0
  35. fkat/pytorch/callbacks/loggers.py +211 -0
  36. fkat/pytorch/callbacks/logging/__init__.py +12 -0
  37. fkat/pytorch/callbacks/logging/heartbeat.py +76 -0
  38. fkat/pytorch/callbacks/logging/throughput.py +253 -0
  39. fkat/pytorch/callbacks/logging/validation_metrics.py +94 -0
  40. fkat/pytorch/callbacks/monitoring/__init__.py +14 -0
  41. fkat/pytorch/callbacks/monitoring/crash.py +162 -0
  42. fkat/pytorch/callbacks/monitoring/dp.py +130 -0
  43. fkat/pytorch/callbacks/monitoring/hardware_stats.py +135 -0
  44. fkat/pytorch/callbacks/monitoring/shutdown.py +170 -0
  45. fkat/pytorch/callbacks/profiling/__init__.py +13 -0
  46. fkat/pytorch/callbacks/profiling/flops.py +574 -0
  47. fkat/pytorch/callbacks/profiling/memray.py +212 -0
  48. fkat/pytorch/callbacks/profiling/torch.py +197 -0
  49. fkat/pytorch/callbacks/profiling/viztracer.py +197 -0
  50. fkat/pytorch/loggers.py +284 -0
  51. fkat/pytorch/schedule/__init__.py +27 -0
  52. fkat/pytorch/schedule/base.py +308 -0
  53. fkat/pytorch/schedule/mlflow.py +143 -0
  54. fkat/pytorch/utilities.py +49 -0
  55. fkat/test.py +31 -0
  56. fkat/train.py +32 -0
  57. fkat/utils/__init__.py +28 -0
  58. fkat/utils/aws/__init__.py +3 -0
  59. fkat/utils/aws/imds.py +137 -0
  60. fkat/utils/boto3.py +24 -0
  61. fkat/utils/config.py +194 -0
  62. fkat/utils/cuda/__init__.py +3 -0
  63. fkat/utils/cuda/preflight/__init__.py +3 -0
  64. fkat/utils/cuda/preflight/health_check/aws_instance_config.py +82 -0
  65. fkat/utils/cuda/preflight/health_check/constants.py +23 -0
  66. fkat/utils/cuda/preflight/health_check/ddb_client.py +82 -0
  67. fkat/utils/cuda/preflight/health_check/gpu_connection_test.py +104 -0
  68. fkat/utils/cuda/preflight/health_check/gpu_stress_test.py +122 -0
  69. fkat/utils/cuda/preflight/health_check/helpers.py +297 -0
  70. fkat/utils/cuda/preflight/health_check/logger.py +205 -0
  71. fkat/utils/cuda/preflight/health_check/timer.py +31 -0
  72. fkat/utils/cuda/preflight/run.py +560 -0
  73. fkat/utils/cuda/xid.py +48 -0
  74. fkat/utils/logging.py +28 -0
  75. fkat/utils/mlflow.py +33 -0
  76. fkat/utils/pandas.py +25 -0
  77. fkat/utils/pdb.py +84 -0
  78. fkat/utils/pool.py +81 -0
  79. fkat/utils/profiler.py +18 -0
  80. fkat/utils/pyarrow.py +21 -0
  81. fkat/utils/rng.py +27 -0
  82. fkat/utils/shm.py +184 -0
  83. fkat/validate.py +31 -0
  84. fkat-0.1.2.dist-info/METADATA +134 -0
  85. fkat-0.1.2.dist-info/RECORD +88 -0
  86. fkat-0.1.2.dist-info/WHEEL +4 -0
  87. fkat-0.1.2.dist-info/licenses/LICENSE +175 -0
  88. fkat-0.1.2.dist-info/licenses/NOTICE +1 -0
fkat/utils/cuda/preflight/health_check/aws_instance_config.py
@@ -0,0 +1,82 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from dataclasses import dataclass
+
+
+ @dataclass(frozen=True)
+ class InstanceBenchmarkConfig:
+     name: str
+     gpu_memory_gb: int
+     config_dim: int
+     num_loops_single: int
+     num_loops_multi: int
+     baseline_num_loops: int
+     baseline_single_node_latency: float
+     baseline_pair_nodes_latency: float
+
+
+ P4D_24XLARGE = InstanceBenchmarkConfig(
+     name="p4d.24xlarge",
+     gpu_memory_gb=18,
+     config_dim=4_000_000_000,
+     num_loops_single=200,
+     num_loops_multi=200,
+     baseline_num_loops=43,
+     baseline_single_node_latency=261,
+     baseline_pair_nodes_latency=772,
+ )
+
+ P4DE_24XLARGE = InstanceBenchmarkConfig(
+     name="p4de.24xlarge",
+     gpu_memory_gb=18,
+     config_dim=4_000_000_000,
+     num_loops_single=200,
+     num_loops_multi=200,
+     baseline_num_loops=43,
+     baseline_single_node_latency=261,
+     baseline_pair_nodes_latency=772,
+ )
+
+ P5_48XLARGE = InstanceBenchmarkConfig(
+     name="p5.48xlarge",
+     gpu_memory_gb=40,
+     config_dim=8_000_000_000,
+     num_loops_single=200,
+     num_loops_multi=200,
+     baseline_num_loops=24,
+     baseline_single_node_latency=324,
+     baseline_pair_nodes_latency=419,
+ )
+
+ P5EN_48XLARGE = InstanceBenchmarkConfig(
+     name="p5en.48xlarge",
+     gpu_memory_gb=75,
+     config_dim=16_700_000_000,
+     num_loops_single=100,
+     num_loops_multi=100,
+     baseline_num_loops=9,
+     baseline_single_node_latency=660,
+     baseline_pair_nodes_latency=780,
+ )
+
+ DEFAULT_INSTANCE = InstanceBenchmarkConfig(
+     name="default",
+     gpu_memory_gb=10,
+     config_dim=1_000,
+     num_loops_single=0,
+     num_loops_multi=0,
+     baseline_num_loops=0,
+     baseline_single_node_latency=0,
+     baseline_pair_nodes_latency=0,
+ )
+
+ INSTANCE_BENCHMARK_CONFIGS: dict[str, InstanceBenchmarkConfig] = {
+     cfg.name: cfg
+     for cfg in [
+         P4D_24XLARGE,
+         P4DE_24XLARGE,
+         P5_48XLARGE,
+         P5EN_48XLARGE,
+         DEFAULT_INSTANCE,
+     ]
+ }
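For context, a minimal usage sketch (not part of the package) of the registry above. `resolve_benchmark_config` is a hypothetical helper name; the fallback mirrors the zeroed "default" entry, whose loop counts of 0 effectively skip the benchmark comparison.

# Hypothetical consumer of INSTANCE_BENCHMARK_CONFIGS (illustrative only).
from fkat.utils.cuda.preflight.health_check.aws_instance_config import (
    DEFAULT_INSTANCE,
    INSTANCE_BENCHMARK_CONFIGS,
    InstanceBenchmarkConfig,
)

def resolve_benchmark_config(instance_type: str) -> InstanceBenchmarkConfig:
    # Unknown instance types fall back to DEFAULT_INSTANCE.
    return INSTANCE_BENCHMARK_CONFIGS.get(instance_type, DEFAULT_INSTANCE)

cfg = resolve_benchmark_config("p5.48xlarge")
print(cfg.config_dim, cfg.baseline_single_node_latency)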
fkat/utils/cuda/preflight/health_check/constants.py
@@ -0,0 +1,23 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ INSTANCE_HEALTH_STATUS_DDB_TABLE_NAME = "NodeHealthStatus"
+
+ # GPU stress test configs:
+ STRESS_TEST_MAX_RUNTIME_IN_SEC = 120
+ HEALTH_CHECK_TIMEOUT_SECS = 150
+
+ # Good/bad node decision factor
+ NUM_LOOPS_RANGE = 1
+ SINGLE_NODE_LATENCY_THRESHOLD_FACTOR = 1.1
+ PAIR_NODES_LATENCY_THRESHOLD_FACTOR = 1.1
+ PREFLIGHT_MLFLOW_METRIC_PREFIX = "preflight"
+
+ # Max len of a node IPv4 address
+ MAX_ADDR_LENGTH = 100
+ PASS = "pass"
+ FAIL = "fail"
+
+ MLFLOW_EXPERIMENT_NAME = "bad_node_detection_{region}"
+
+ AWS_BATCH_JOB_ID = "AWS_BATCH_JOB_ID"
+ AWS_BATCH_LINK = "https://{region}.console.aws.amazon.com/batch/home?region={region}#jobs/ec2/detail/{batch_id}"
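A short sketch (not part of the package) of how the templated constants above might be filled in; the environment-variable fallbacks are assumptions for illustration.

# Hypothetical usage of the template constants (illustrative only).
import os

from fkat.utils.cuda.preflight.health_check.constants import (
    AWS_BATCH_JOB_ID,
    AWS_BATCH_LINK,
    MLFLOW_EXPERIMENT_NAME,
)

region = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
batch_id = os.environ.get(AWS_BATCH_JOB_ID, "unknown")

experiment = MLFLOW_EXPERIMENT_NAME.format(region=region)  # e.g. "bad_node_detection_us-east-1"
job_link = AWS_BATCH_LINK.format(region=region, batch_id=batch_id)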
fkat/utils/cuda/preflight/health_check/ddb_client.py
@@ -0,0 +1,82 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ from __future__ import annotations
+
+ import boto3
+ from boto3.dynamodb.conditions import Key
+ from botocore.exceptions import ClientError
+ import logging
+ from typing import TYPE_CHECKING, Any
+ from fkat.utils.cuda.preflight.health_check.constants import INSTANCE_HEALTH_STATUS_DDB_TABLE_NAME
+ import datetime
+
+ if TYPE_CHECKING:
+     from types_boto3_dynamodb.service_resource import DynamoDBServiceResource, Table
+
+
+ class HealthStatusDDBClient:
+     _PARTITION_KEY = "instance_gpu_hash_id"
+     _SORT_KEY = "time_checked"
+     _is_initialized = False
+     _instance = None
+
+     def __new__(cls, *args: Any, **kwargs: Any) -> HealthStatusDDBClient:
+         if cls._instance is None:
+             cls._instance = super().__new__(cls)
+         return cls._instance
+
+     def __init__(self, region: str = "us-east-1") -> None:
+         if not self._is_initialized:
+             session = boto3.Session()
+             self.ddb_resource: DynamoDBServiceResource = session.resource("dynamodb", region_name=region)  # type: ignore[assignment]
+             self.table: Table = self.ddb_resource.Table(INSTANCE_HEALTH_STATUS_DDB_TABLE_NAME)
+
+             self._is_initialized = True
+
+     def generate_ddb_item(
+         self,
+         instance_gpu_hash_id: str,
+         instance_health: bool,
+         gpu_stats: dict[str | int, dict[str, Any]],
+         batch_job_id: str,
+         instance_id: str,
+         instance_type: str,
+         test_result: dict[str, Any],
+     ) -> dict[str, Any]:
+         return {
+             "instance_gpu_hash_id": instance_gpu_hash_id,
+             "time_checked": datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
+             "batch_job_id": batch_job_id,
+             "instance_type": instance_type,
+             "gpu_info": {str(key): value for key, value in gpu_stats.items()},
+             "instance_id": instance_id,
+             "healthy": instance_health,
+             "test_result": test_result,
+         }
+
+     def put_item(self, item: dict[str, Any]) -> None:
+         try:
+             response = self.table.put_item(Item=item, ReturnValues="ALL_OLD")
+             logging.info(f"Item {response} successfully added to ddb.")
+         except ClientError as e:
+             # If the item already exists, an exception will be raised
+             if e.response["Error"]["Code"] == "ConditionalCheckFailedException":
+                 logging.error("Item already exists. Duplicate insertion prevented.")
+             else:
+                 logging.error(f"Error inserting item: {e.response['Error']['Message']}")
+
+             raise e
+
+     def get_item(self, partition_key: str) -> dict[str, Any] | None:
+         try:
+             response = self.table.query(
+                 KeyConditionExpression=Key(self._PARTITION_KEY).eq(partition_key),
+                 ScanIndexForward=False,
+                 Limit=1,
+             )
+
+             logging.info(f"Successfully got response {response} from table with key {partition_key}")
+             return response["Items"][0] if "Items" in response and response["Items"] else None
+         except Exception as e:
+             logging.error(f"An unexpected error occurred: {e}")
+             raise e
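A hedged usage sketch (not part of the package) of the singleton client above. All field values are illustrative, and the calls assume AWS credentials plus an existing NodeHealthStatus table in the target region.

# Hypothetical round trip through HealthStatusDDBClient (illustrative only).
from fkat.utils.cuda.preflight.health_check.ddb_client import HealthStatusDDBClient

client = HealthStatusDDBClient(region="us-east-1")  # singleton: later constructions reuse this instance

item = client.generate_ddb_item(
    instance_gpu_hash_id="i-0123456789abcdef0-gpuhash",  # illustrative key
    instance_health=True,
    gpu_stats={0: {"status": "Healthy"}},
    batch_job_id="example-batch-job-id",
    instance_id="i-0123456789abcdef0",
    instance_type="p5.48xlarge",
    test_result={"single_node_latency": 310.0},
)
client.put_item(item)

latest = client.get_item("i-0123456789abcdef0-gpuhash")  # most recent record for the key, or None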
fkat/utils/cuda/preflight/health_check/gpu_connection_test.py
@@ -0,0 +1,104 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ import torch
+ import mlflow
+ from datetime import timedelta
+ import torch.distributed as dist
+ from torch.multiprocessing import Queue
+
+ import logging
+ from fkat.utils.cuda.preflight.health_check.timer import Timer
+ from fkat.utils.cuda.preflight.health_check.helpers import (
+     destroy_process_group_if_initialized,
+ )
+ from fkat.utils.cuda.preflight.health_check.logger import end_all_mlflow_active_runs
+ from fkat.utils.cuda.preflight.health_check.constants import (
+     MLFLOW_EXPERIMENT_NAME,
+     PREFLIGHT_MLFLOW_METRIC_PREFIX,
+ )
+
+ PG_TIMEOUT_MIN = float(os.environ.get("PG_TIMEOUT", 5))
+
+
+ def run_gpu_connection_test(
+     mlflow_run_id: str,
+     result_queue: Queue,
+     dim_items: int,
+     loops: int,
+     master_addr: str,
+     master_port: str,
+     world_size: int,
+     rank: int,
+     device_id: int | None = None,
+     mode: str = "single",
+ ) -> None:
+     """
+     Runs a GPU connectivity and communication benchmark test using NCCL and logs performance metrics to MLflow.
+
+     This function initializes a distributed process group with NCCL, performs a warm-up `all_reduce`, and
+     repeatedly performs `all_reduce` operations to test GPU communication latency. It records timing statistics
+     for each iteration, logs them to MLflow, and places the results in the provided queue.
+
+     Args:
+         mlflow_run_id (str): The ID of the MLflow run to log metrics under.
+         result_queue (Queue): A multiprocessing-safe queue where timing results are pushed.
+         dim_items (int): The dimension of the square tensor used for the all_reduce operation.
+         loops (int): The number of all_reduce iterations to run for benchmarking.
+         master_addr (str): Internal address used to initialize the process group.
+         master_port (str): Port used to initialize the process group.
+         world_size (int): Number of processes expected in the process group.
+         rank (int): Rank of the current process in the process group.
+         device_id (int | None, optional): The CUDA device ID to use for testing. If None, defaults to the current device.
+         mode (str, optional): Mode label to tag the MLflow metrics. Defaults to "single".
+
+     Returns:
+         None
+     """
+     dist.init_process_group(  # type: ignore[possibly-unbound-attribute]
+         backend="nccl",
+         init_method=f"tcp://{master_addr}:{master_port}",
+         world_size=world_size,
+         rank=rank,
+         timeout=timedelta(minutes=PG_TIMEOUT_MIN),
+     )
+
+     if mlflow.active_run():
+         end_all_mlflow_active_runs()
+
+     mlflow.set_tracking_uri(uri=os.environ["MLFLOW_URI"])
+     mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME.format(region=os.environ["AWS_DEFAULT_REGION"]))
+     mlflow.start_run(run_id=mlflow_run_id)  # type: ignore[possibly-unbound-attribute]
+
+     logging.info(f"inside run_gpu_connection_test, {dim_items}, {loops}, {device_id}")
+     timer = Timer()
+     device_id = device_id if device_id is not None else torch.cuda.current_device()
+
+     with timer("cuda"):
+         device = torch.device("cuda", device_id)
+         torch.cuda.set_device(device_id)
+         buffer = torch.ones((dim_items, dim_items), device=device, dtype=torch.float64)
+
+     # warmup
+     dist.all_reduce(buffer, op=dist.ReduceOp.AVG, async_op=False)  # type: ignore[possibly-unbound-attribute]
+
+     results = []
+     for i in range(loops):
+         with timer(f"all_reduce_{i}"):
+             with timer("send"):
+                 waiter = dist.all_reduce(buffer, op=dist.ReduceOp.AVG, async_op=True)  # type: ignore[possibly-unbound-attribute]
+             with timer("sync"):
+                 waiter.wait()
+                 dist.barrier()  # type: ignore[possibly-unbound-attribute]
+             with timer("stat"):
+                 buffer_sum = buffer.sum().item()  # noqa: F841
+
+         results.append(timer[f"all_reduce_{i}"])
+
+         mlflow.log_metric(  # type: ignore[possibly-unbound-attribute]
+             f"{PREFLIGHT_MLFLOW_METRIC_PREFIX}/{mode}_node_test/latency_gpu_{os.environ['LOCAL_RANK']}",
+             timer[f"all_reduce_{i}"],
+         )
+
+     result_queue.put(results)
+     destroy_process_group_if_initialized()
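A hypothetical launch sketch (not part of the package) showing how run_gpu_connection_test might be driven for a single-node test with one spawned process per GPU. It assumes MLFLOW_URI, AWS_DEFAULT_REGION and LOCAL_RANK are set in each child's environment, that an MLflow run with run_id already exists, and that 127.0.0.1:29500 is a usable rendezvous address; the package's own orchestration (see fkat/utils/cuda/preflight/run.py in the file list) may differ.

# Hypothetical single-node launcher (illustrative only).
import torch
import torch.multiprocessing as mp

from fkat.utils.cuda.preflight.health_check.gpu_connection_test import run_gpu_connection_test

def launch_single_node_test(run_id: str, dim_items: int, loops: int) -> list:
    world_size = torch.cuda.device_count()
    ctx = mp.get_context("spawn")
    result_queue = ctx.Queue()
    procs = []
    for rank in range(world_size):
        p = ctx.Process(
            target=run_gpu_connection_test,
            # device_id=rank pins each worker to its own GPU; mode tags the MLflow metric name.
            args=(run_id, result_queue, dim_items, loops, "127.0.0.1", "29500", world_size, rank, rank, "single"),
        )
        p.start()
        procs.append(p)
    results = [result_queue.get() for _ in procs]  # drain per-rank all_reduce timings before joining
    for p in procs:
        p.join()
    return results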
fkat/utils/cuda/preflight/health_check/gpu_stress_test.py
@@ -0,0 +1,122 @@
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ import os
+ import math
+ import mlflow
+ import time
+ from typing import Any
+ import torch
+ from torch.multiprocessing import Queue
+ import logging
+ from fkat.utils.cuda.preflight.health_check.constants import (
+     MLFLOW_EXPERIMENT_NAME,
+     PREFLIGHT_MLFLOW_METRIC_PREFIX,
+ )
+ from fkat.utils.cuda.preflight.health_check.logger import end_all_mlflow_active_runs
+
+
+ def run_gpu_stress_test(mlflow_run_id: str, result_queue: Queue, gpu_mem: int, max_runtime: int) -> None:
+     """
+     Performs a multi-GPU stress test by executing repeated matrix multiplications and inter-GPU memory transfers.
+
+     This function:
+     - Allocates large tensors on each GPU (assuming 8 GPUs),
+     - Performs repeated `matmul` operations to stress GPU compute,
+     - Copies results across GPUs to test memory transfer integrity,
+     - Verifies data correctness after each transfer,
+     - Logs metrics to MLflow regarding correctness and loop iterations,
+     - Returns a dictionary summarizing the health of each GPU via the result queue.
+
+     Args:
+         mlflow_run_id (str): The MLflow run ID under which metrics are logged.
+         result_queue (Queue): A multiprocessing-safe queue to place GPU health results.
+         gpu_mem (int): Approximate GPU memory (in GB) to target when allocating stress test tensors.
+         max_runtime (int): Maximum runtime (in seconds) to perform the stress test.
+
+     Returns:
+         None
+     """
+     if not torch.cuda.is_available():
+         result_queue.put({})
+         return
+     mlflow.set_tracking_uri(uri=os.environ["MLFLOW_URI"])
+     mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME.format(region=os.environ["AWS_DEFAULT_REGION"]))
+     if mlflow.active_run():
+         end_all_mlflow_active_runs()
+
+     mlflow.start_run(run_id=mlflow_run_id)  # type: ignore[possibly-unbound-attribute]
+
+     # Get the array size for a square array that fills 1/4 of memory with 2 byte values
+     arr_size = (((gpu_mem / 4) * 10**9) / 2) ** (1 / 2)
+     arr_size = int(math.ceil(arr_size))
+     num_gpus = torch.cuda.device_count()
+     logging.info(f"inside run_gpu_stress_test(), num_gpus is: {num_gpus}")
+     if num_gpus != 8:
+         result_queue.put({})
+         return
+
+     gpu_health: dict[str, Any] = {str(idx): "Unknown" for idx in range(num_gpus)}
+     gpu_health["check_record"] = []
+
+     Ts = [torch.ones(arr_size, arr_size, dtype=torch.bfloat16, device=f"cuda:{gpu_num}") for gpu_num in range(num_gpus)]
+     results = [
+         torch.zeros(arr_size, arr_size, dtype=torch.bfloat16, device=f"cuda:{gpu_num}") for gpu_num in range(num_gpus)
+     ]
+     from_others = [
+         torch.zeros(arr_size, arr_size, dtype=torch.bfloat16, device=f"cuda:{gpu_num}") for gpu_num in range(num_gpus)
+     ]
+
+     torch.manual_seed(12345)
+
+     start_time = time.time()
+     curr_loop_num = 0
+     while time.time() - start_time < max_runtime:
+         # Matrix multiply into result
+         # TODO: record the latency of each matmul
+         [torch.matmul(T, T, out=result) for T, result in zip(Ts, results, strict=False)]
+
+         # Move into gpu curr_loop_num away
+         for i in range(num_gpus):
+             other_gpu = (curr_loop_num % (num_gpus - 1) + i + 1) % num_gpus
+             other = from_others[other_gpu]
+             original = results[i]
+             other[:] = original
+
+         # Check values are correct
+         checks = [
+             (other == result).sum() == result.numel() for other, result in zip(from_others, results, strict=False)
+         ]
+
+         for idx, check in enumerate(checks):
+             if not check.item():
+                 gpu_health[str(idx)] = "Unhealthy"
+
+         gpu_health["check_record"].append([check.item() for check in checks])
+
+         curr_loop_num += 1
+
+     mlflow.log_metric(  # type: ignore[possibly-unbound-attribute]
+         f"{PREFLIGHT_MLFLOW_METRIC_PREFIX}/gpu_stress_test/all_gpu_calculation_is_correct",
+         all(checks),
+     )
+     mlflow.log_metric(  # type: ignore[possibly-unbound-attribute]
+         f"{PREFLIGHT_MLFLOW_METRIC_PREFIX}/gpu_stress_test/num_loop_gpu_stress_test",
+         curr_loop_num,
+     )
+
+     gpu_health["num_loop"] = str(curr_loop_num)
+
+     logging.info(f"Finished run_gpu_stress_test. {curr_loop_num} loops ran.")
+
+     if curr_loop_num < num_gpus:
+         logging.info(f"Few loops seen, only {curr_loop_num}")
+         for idx in range(num_gpus):
+             gpu_health[str(idx)] = "Unknown"
+
+     gpu_health["all_gpus"] = all(gpu_health.get(str(i)) == "Healthy" for i in range(num_gpus))
+
+     # Free memory
+     del Ts, results, from_others
+     torch.cuda.empty_cache()
+
+     result_queue.put(gpu_health)
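The tensor sizing above targets roughly a quarter of the configured gpu_memory_gb per buffer: for the p5.48xlarge entry (gpu_memory_gb=40), arr_size = ceil(sqrt((40 / 4) * 1e9 / 2)) ≈ 70,711, i.e. a bfloat16 square of about 10 GB, and each GPU holds three such buffers (Ts, results, from_others). Below is a hypothetical launch sketch (not part of the package) that runs the test in a spawned subprocess so a hung GPU cannot block the parent; it assumes MLFLOW_URI and AWS_DEFAULT_REGION are set and that run_id names an existing MLflow run.

# Hypothetical stress-test launcher (illustrative only).
import torch.multiprocessing as mp

from fkat.utils.cuda.preflight.health_check.constants import STRESS_TEST_MAX_RUNTIME_IN_SEC
from fkat.utils.cuda.preflight.health_check.gpu_stress_test import run_gpu_stress_test

def launch_stress_test(run_id: str, gpu_memory_gb: int) -> dict:
    ctx = mp.get_context("spawn")
    result_queue = ctx.Queue()
    p = ctx.Process(
        target=run_gpu_stress_test,
        args=(run_id, result_queue, gpu_memory_gb, STRESS_TEST_MAX_RUNTIME_IN_SEC),
    )
    p.start()
    gpu_health = result_queue.get()  # empty dict if CUDA is unavailable or the GPU count is not 8
    p.join()
    return gpu_health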