checkpoint-engine 0.3.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/__init__.py +4 -0
- checkpoint_engine/_version.py +34 -0
- checkpoint_engine/device_utils.py +86 -0
- checkpoint_engine/ps.py +1576 -0
- checkpoint_engine/worker.py +168 -0
- checkpoint_engine-0.3.0rc0.dist-info/METADATA +236 -0
- checkpoint_engine-0.3.0rc0.dist-info/RECORD +10 -0
- checkpoint_engine-0.3.0rc0.dist-info/WHEEL +5 -0
- checkpoint_engine-0.3.0rc0.dist-info/licenses/LICENCE +21 -0
- checkpoint_engine-0.3.0rc0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# file generated by setuptools-scm
# don't change, don't track in version control

# Public names exported by this module: each value is available under both a
# dunder name and a plain alias.
__all__ = [
    "__version__",
    "__version_tuple__",
    "version",
    "version_tuple",
    "__commit_id__",
    "commit_id",
]

# Always False at runtime, so the ``typing`` imports below are never executed;
# the first branch exists only for static analysis.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
    COMMIT_ID = Union[str, None]
else:
    # Runtime placeholders so the annotations below can be evaluated.
    VERSION_TUPLE = object
    COMMIT_ID = object

# Version string and its alias.
version: str = '0.3.0rc0'
__version__: str = version

# Version split into components, and its alias.
__version_tuple__: VERSION_TUPLE = (0, 3, 0, 'rc0')
version_tuple: VERSION_TUPLE = __version_tuple__

# No commit id is recorded for this build.
commit_id: COMMIT_ID = None
__commit_id__: COMMIT_ID = commit_id
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import socket
|
|
4
|
+
import subprocess
|
|
5
|
+
from functools import lru_cache
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from loguru import logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@lru_cache(maxsize=1)
def get_ip() -> str:
    """Return this host's IPv4 address as a string.

    The address is read from the local end of a UDP socket connected to
    ``8.8.8.8:80``. If anything in that attempt raises, a warning is logged
    and the address the local hostname resolves to is returned instead.
    The returned value is memoized by ``lru_cache``.
    """
    try:
        # try to get ip from network interface
        probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        with probe:
            probe.connect(("8.8.8.8", 80))
            local_endpoint = probe.getsockname()
            return local_endpoint[0]
    except Exception as e:  # noqa: BLE001
        # fallback to get ip from hostname
        message = f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
        logger.warning(message)
        hostname = socket.gethostname()
        return socket.gethostbyname(hostname)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def npu_generate_uuid() -> str:
    """Build an identifier for the NPU chip the current process is running on.

    Queries ``npu-smi info -t proc-mem -i <npu_id>`` for NPU ids 0..7 and looks
    for the current PID in the output. For the first NPU whose output lists the
    PID, the ``Chip Count`` value and the first ``Chip ID`` value that appears
    after the PID are combined into a flat chip index.

    Returns:
        ``"<ip>-<index>"`` where ``<ip>`` comes from ``get_ip()`` and
        ``<index>`` is ``npu_id * chip_count + chip_id``.

    Raises:
        ValueError: if ``npu-smi`` is not installed or exits with a non-zero
            status, if no queried NPU lists the current PID, or if the output
            for the matching NPU lacks the ``Chip Count`` / ``Chip ID`` fields.
    """
    not_on_npu = "The current process is not running on the npu device"
    # Match the PID as a whole number only, so that e.g. PID 123 is not
    # mistakenly found inside 11234.
    pid_pattern = re.compile(rf"(?<!\d){os.getpid()}(?!\d)")
    npu_num = 8
    try:
        for npu_id in range(npu_num):
            cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)  # noqa: S603
            str_result = str(result.stdout)
            pid_match = pid_pattern.search(str_result)
            if pid_match is None:
                continue
            # In A3 server, one NPU has two chips.
            match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result)
            # Only look at the text after the PID to find the chip it belongs to.
            match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", str_result[pid_match.end() :])
            if match_chip_count is None or match_chip_id is None:
                raise ValueError(
                    f"Unexpected npu-smi output for NPU {npu_id}: "
                    "'Chip Count' or 'Chip ID' not found"
                )
            chip_count = int(match_chip_count.group(1))
            chip_id = int(match_chip_id.group(1))
            return f"{get_ip()}-{npu_id * chip_count + chip_id}"
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # FileNotFoundError: the npu-smi executable is not installed at all.
        raise ValueError(not_on_npu) from e
    raise ValueError(not_on_npu)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class DeviceManager:
    """Detect the available accelerator and expose its torch module and backend name.

    Attributes set by ``__init__``:
        device_type: ``"npu"`` or ``"cuda"``.
        device_module: ``torch_npu.npu`` or ``torch.cuda`` respectively.
    """

    def __init__(self):
        """Detect the device type, then bind the matching device module.

        Raises:
            TypeError: if neither an NPU nor CUDA is available.
        """
        self.device_type = self._detect_device_type()
        self._setup_device_module()

    def _is_torch_npu_available(self) -> bool:
        """Return ``torch.npu.is_available()`` if that callable exists, else ``False``.

        An ``ImportError`` raised during the check is treated as "not available".
        """
        try:
            npu_namespace = getattr(torch, "npu", None)
            availability_probe = getattr(npu_namespace, "is_available", None)
            if not callable(availability_probe):
                return False
            return torch.npu.is_available()
        except ImportError:
            return False

    def _detect_device_type(self) -> str:
        """Return ``"npu"`` or ``"cuda"``; NPU is checked first.

        Raises:
            TypeError: if neither device is available.
        """
        if self._is_torch_npu_available():
            return "npu"
        if torch.cuda.is_available():
            return "cuda"
        raise TypeError("The current device type is not supported")

    def _setup_device_module(self):
        """Set ``self.device_module`` according to ``self.device_type``.

        Raises:
            TypeError: if ``self.device_type`` is neither ``"npu"`` nor ``"cuda"``.
        """
        kind = self.device_type
        if kind == "npu":
            # Imported here so the package is only required on NPU hosts.
            import torch_npu

            self.device_module = torch_npu.npu
            return
        if kind == "cuda":
            self.device_module = torch.cuda
            return
        raise TypeError("The current device type is not supported")

    @property
    def backend(self) -> str:
        """Name of the collective backend: ``"hccl"`` for NPU, ``"nccl"`` for CUDA.

        Raises:
            TypeError: if ``self.device_type`` is neither ``"npu"`` nor ``"cuda"``.
        """
        backend_by_device = {"npu": "hccl", "cuda": "nccl"}
        backend_name = backend_by_device.get(self.device_type)
        if backend_name is None:
            raise TypeError("The current device type is not supported")
        return backend_name
|