checkpoint-engine 0.3.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4 @@
1
# Use the version string generated into ``_version`` by setuptools-scm when it
# exists; otherwise (e.g. an unbuilt source checkout) report "dev".
try:
    from . import _version
except ImportError:
    __version__ = "dev"
else:
    # A ``_version`` module lacking ``__version__`` is treated like a missing one.
    __version__ = getattr(_version, "__version__", "dev")
@@ -0,0 +1,34 @@
1
# file generated by setuptools-scm
# don't change, don't track in version control

# Names exported by ``from ._version import *``. Every value below is published
# under both a plain and a dunder spelling.
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# ``typing`` is only imported for static type checkers (which treat
# TYPE_CHECKING as true); at runtime both aliases are simply ``object``.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    VERSION_TUPLE = COMMIT_ID = object

version: str = "0.3.0rc0"
__version__: str = version
__version_tuple__: VERSION_TUPLE = (0, 3, 0, "rc0")
version_tuple: VERSION_TUPLE = __version_tuple__
commit_id: COMMIT_ID = None
__commit_id__: COMMIT_ID = commit_id
@@ -0,0 +1,86 @@
1
+ import os
2
+ import re
3
+ import socket
4
+ import subprocess
5
+ from functools import lru_cache
6
+
7
+ import torch
8
+ from loguru import logger
9
+
10
+
11
@lru_cache(maxsize=1)
def get_ip() -> str:
    """Return this host's IPv4 address, computed once and then memoized.

    The address is the local end of an ``AF_INET`` UDP socket connected towards
    ``8.8.8.8:80``. If that fails for any reason, a warning is logged and the
    address resolved from the machine's hostname is returned instead.

    Raises:
        Whatever ``socket.gethostbyname`` raises if the hostname fallback also fails.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_probe:
            # connect() on a datagram socket only records the default peer and
            # lets the kernel choose the outgoing local address; no packet is sent.
            udp_probe.connect(("8.8.8.8", 80))
            local_endpoint = udp_probe.getsockname()
    except Exception as err:  # noqa: BLE001
        logger.warning(
            f"fail to get ip from network interface, fallback to get ip from hostname: {err}"
        )
        return socket.gethostbyname(socket.gethostname())
    return local_endpoint[0]
24
+
25
+
26
def npu_generate_uuid(*, npu_num: int = 8) -> str:
    """Return an identifier ``"<ip>-<logical chip index>"`` for the current process.

    For each NPU id in ``range(npu_num)`` this runs ``npu-smi info -t proc-mem -i <id>``
    and picks the first NPU whose output mentions the current PID. The logical chip
    index is ``npu_id * chip_count + chip_id`` (on A3 servers one NPU has two chips).
    ``chip_count`` comes from the first ``Chip Count`` field of that output, and
    ``chip_id`` comes from the first ``Chip ID`` field that follows the PID.

    Args:
        npu_num: Number of NPU ids to probe (``0 .. npu_num - 1``). Defaults to 8.

    Returns:
        ``f"{get_ip()}-{npu_id * chip_count + chip_id}"``.

    Raises:
        ValueError: if ``npu-smi`` cannot be executed or exits with an error, if no
            probed NPU lists the current PID, or if the chip fields cannot be parsed.
    """
    pid = str(os.getpid())
    # Match the PID only as a whole number. A plain substring test would also hit
    # other PIDs or memory figures that merely contain these digits (123 in 41234).
    pid_pattern = re.compile(rf"(?<!\d){pid}(?!\d)")
    for npu_id in range(npu_num):
        cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)  # noqa: S603
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers a missing/unexecutable npu-smi binary, which also
            # means this process is not on an NPU host.
            raise ValueError("The current process is not running on the npu device") from e
        output = result.stdout  # already str because text=True
        pid_match = pid_pattern.search(output)
        if pid_match is None:
            continue
        # In A3 server, one NPU has two chips.
        chip_count_match = re.search(r"Chip Count[^\d]*(\d+)", output)
        # NOTE(review): assumes npu-smi prints the owning chip's "Chip ID" after the
        # PID entry, as the original code did -- confirm against npu-smi output format.
        chip_id_match = re.search(r"Chip ID[^\d]*(\d+)", output[pid_match.end() :])
        if chip_count_match is None or chip_id_match is None:
            raise ValueError(f"Unable to parse chip information from npu-smi output for NPU {npu_id}")
        chip_count = int(chip_count_match.group(1))
        chip_id = int(chip_id_match.group(1))
        return f"{get_ip()}-{npu_id * chip_count + chip_id}"
    raise ValueError("The current process is not running on the npu device")
45
+
46
+
47
class DeviceManager:
    """Detect the local accelerator type and expose its torch device module.

    NPU detection takes precedence over CUDA. Construction raises ``TypeError``
    when neither is available.

    Attributes:
        device_type: ``"npu"`` or ``"cuda"``.
        device_module: ``torch_npu.npu`` for NPUs, ``torch.cuda`` for CUDA GPUs.
    """

    def __init__(self):
        self.device_type = self._detect_device_type()
        self._setup_device_module()

    def _is_torch_npu_available(self) -> bool:
        """Return ``torch.npu.is_available()`` when torch exposes it, else ``False``.

        An ``ImportError`` raised while probing ``torch.npu`` is treated as "no NPU".
        """
        try:
            npu_namespace = getattr(torch, "npu", None)
            availability_probe = getattr(npu_namespace, "is_available", None)
            return availability_probe() if callable(availability_probe) else False
        except ImportError:
            return False

    def _detect_device_type(self) -> str:
        """Return ``"npu"`` or ``"cuda"``; raise ``TypeError`` if neither is available."""
        if self._is_torch_npu_available():
            return "npu"
        if torch.cuda.is_available():
            return "cuda"
        raise TypeError("The current device type is not supported")

    def _setup_device_module(self):
        """Bind ``self.device_module`` to the torch module matching ``self.device_type``."""
        if self.device_type == "cuda":
            self.device_module = torch.cuda
        elif self.device_type == "npu":
            # Deferred import: torch_npu is only needed (and presumably only
            # installed) when an NPU was detected.
            import torch_npu

            self.device_module = torch_npu.npu
        else:
            raise TypeError("The current device type is not supported")

    @property
    def backend(self) -> str:
        """Collective-communication backend name for this device type.

        ``"hccl"`` for NPU, ``"nccl"`` for CUDA; any other device type raises ``TypeError``.
        """
        backend_name = {"npu": "hccl", "cuda": "nccl"}.get(self.device_type)
        if backend_name is None:
            raise TypeError("The current device type is not supported")
        return backend_name