torchruntime 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchruntime/__init__.py +2 -0
- torchruntime/__main__.py +52 -0
- torchruntime/configure.py +80 -0
- torchruntime/consts.py +5 -0
- torchruntime/device_db.py +138 -0
- torchruntime/gpu_pci_ids.db +0 -0
- torchruntime/installer.py +98 -0
- torchruntime/platform_detection.py +115 -0
- torchruntime/tests/test_configure.py +137 -0
- torchruntime/tests/test_device_db.py +50 -0
- torchruntime/tests/test_device_db_integration.py +50 -0
- torchruntime/tests/test_installer.py +109 -0
- torchruntime/tests/test_platform_detection.py +164 -0
- torchruntime-1.0.2.dist-info/LICENSE +21 -0
- torchruntime-1.0.2.dist-info/METADATA +145 -0
- torchruntime-1.0.2.dist-info/RECORD +18 -0
- torchruntime-1.0.2.dist-info/WHEEL +5 -0
- torchruntime-1.0.2.dist-info/top_level.txt +1 -0
torchruntime/__init__.py
ADDED
torchruntime/__main__.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from .installer import install
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def print_usage():
    """Print usage information with examples."""
    # NOTE(review): the exact column alignment inside this usage text was lost in
    # transit — the wording below is verbatim, the indentation is reconstructed.
    usage = """
Usage: python -m torchruntime <command> [arguments]

Commands:
    install     Install PyTorch packages
    --help      Show this help message

Examples:
    python -m torchruntime install
    python -m torchruntime install torch==2.2.0 torchvision==0.17.0
    python -m torchruntime install torch>=2.0.0 torchaudio
    python -m torchruntime install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0

If no packages are specified, the latest available versions
of torch, torchaudio and torchvision will be installed.

Version specification formats (follows pip format):
    package==2.1.0     Exact version
    package>=2.0.0     Minimum version
    package<=2.2.0     Maximum version
    package~=2.1.0     Compatible release
    package==2.1.*     Any 2.1.x version
    package            Latest version
"""
    print(usage.strip())
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def main():
    """Command-line entry point for `python -m torchruntime <command> [...]`."""
    import sys

    args = sys.argv[1:]

    # No command, or an explicit help flag: show usage and stop.
    if not args or args[0] in ["--help", "-h"]:
        print_usage()
        return

    cmd, *rest = args

    if cmd == "install":
        # Forward any package specs after 'install'; None means "install defaults".
        install(rest if rest else None)
    else:
        print(f"Unknown command: {cmd}")
        print_usage()


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from .device_db import get_discrete_gpus
|
|
4
|
+
from .platform_detection import get_torch_platform, os_name
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def init_torch():
    """Detect the GPU platform for this machine and apply platform-specific env settings."""
    gpu_infos = get_discrete_gpus()
    platform_id = get_torch_platform(gpu_infos)

    if platform_id.startswith("rocm"):
        check_rocm_permissions()
        set_rocm_env_vars(gpu_infos, platform_id)
        return
    if os_name == "Darwin":
        # Let torch fall back to CPU for ops that MPS doesn't implement.
        _set_env_vars({"PYTORCH_ENABLE_MPS_FALLBACK": "1"})
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def check_rocm_permissions():
    # Warn (without failing) when /dev/kfd isn't writable — ROCm needs write
    # access to it to initialize the GPU.
    # NOTE(review): the '#' banner padding below is reconstructed; the wording is verbatim.
    if not os.access("/dev/kfd", os.W_OK):
        print(
            """#########################################################################
# No write access to /dev/kfd !                                         #
#########################################################################

Without this, the ROCm driver will probably not be able to initialize the GPU and torch will use the CPU for rendering.

Follow the instructions on this site to configure access to /dev/kfd:
https://github.com/easydiffusion/easydiffusion/wiki/AMD-on-Linux#access-permissions"""
        )
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def set_rocm_env_vars(discrete_gpu_infos, torch_platform):
    """Apply ROCm-specific environment variables based on the detected AMD GPU family.

    Args:
        discrete_gpu_infos: list of (vendor_id, vendor_name, device_id, device_name) tuples.
        torch_platform: the selected platform string (e.g. "rocm6.2"). Currently
            unused by this function, but kept for interface stability.

    Unrecognized GPU families get a warning and NO environment changes.
    """
    device_names = [device_name for *_, device_name in discrete_gpu_infos]

    env = {}

    # interesting reading:
    # gfx config from: https://web.archive.org/web/20241228163540/https://llvm.org/docs/AMDGPUUsage.html#processors
    # more info: https://web.archive.org/web/20241209013717/https://discuss.linuxcontainers.org/t/rocm-and-pytorch-on-amd-apu-or-gpu-ai/19743
    # this thread is great for understanding the status of torch support for RDNA 1 (i.e. 5000 series): https://github.com/ROCm/ROCm/issues/2527
    # past settings from: https://github.com/easydiffusion/easydiffusion/blob/20d77a85a1ed766ece0cc4b6a55dca003bce262c/scripts/check_modules.py#L405-L420

    if any(device_name.startswith("Navi 3") for device_name in device_names):
        print("[INFO] Applying Navi 3x settings")
        env["HSA_OVERRIDE_GFX_VERSION"] = "11.0.0"
    elif any(device_name.startswith("Navi 2") for device_name in device_names):
        print("[INFO] Applying Navi 2x settings")
        env["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"
    elif any(device_name.startswith("Navi 1") for device_name in device_names):
        print("[INFO] Applying Navi 1x settings")
        env["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"
        # env["HSA_ENABLE_SDMA"] = "0" # uncomment this if facing errors like in https://github.com/ROCm/ROCm/issues/2616
        env["FORCE_FULL_PRECISION"] = "yes"  # https://github.com/ROCm/ROCm/issues/2527
        # FORCE_FULL_PRECISION won't be necessary once this is fixed (and torch2 wheels are released for ROCm 6.2): https://github.com/pytorch/pytorch/issues/132570#issuecomment-2313071756
    elif any(device_name.startswith("Vega 2") for device_name in device_names):  # Radeon VII etc
        print("[INFO] Applying Vega 20 settings")
        env["HSA_OVERRIDE_GFX_VERSION"] = "9.0.6"
    elif any(device_name.startswith("Vega 1") for device_name in device_names):  # Radeon RX Vega 56 etc
        print("[INFO] Applying Vega 10 settings")
        env["HSA_OVERRIDE_GFX_VERSION"] = "9.0.0"
        env["FORCE_FULL_PRECISION"] = "yes"
    elif any(device_name.startswith("Ellesmere") for device_name in device_names):  # RX 570, 580, 590, Polaris etc
        print("[INFO] Applying Ellesmere settings")
        env["HSA_OVERRIDE_GFX_VERSION"] = "8.0.3"  # https://github.com/ROCm/ROCm/issues/1659
        env["ROC_ENABLE_PRE_VEGA"] = "1"
    else:
        # FIX: removed a dead `env["ROC_ENABLE_PRE_VEGA"] = "1"` here — this branch
        # returns before `_set_env_vars(env)`, so the assignment was never applied
        # (the test suite codifies that unknown GPUs leave the environment untouched).
        print(f"[WARNING] Unrecognized AMD graphics card: {device_names}")
        return

    num_devices = len(device_names)
    env["HIP_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(num_devices))

    _set_env_vars(env)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _set_env_vars(env):
|
|
78
|
+
for k, v in env.items():
|
|
79
|
+
print(f"[INFO] Setting env variable {k}={v}")
|
|
80
|
+
os.environ[k] = v
|
torchruntime/consts.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import json
|
|
4
|
+
import sqlite3
|
|
5
|
+
import platform
|
|
6
|
+
import subprocess
|
|
7
|
+
|
|
8
|
+
DEVICE_DB_FILE = "gpu_pci_ids.db" # this file will only include AMD, NVIDIA and Discrete Intel GPUs
|
|
9
|
+
|
|
10
|
+
os_name = platform.system()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_windows_output():
    """Return the PNP device IDs of all video controllers (via PowerShell/WMI), or "" on failure."""
    try:
        command = ["powershell", "-Command", "Get-WmiObject Win32_VideoController | ForEach-Object { $_.PNPDeviceID }"]
        return subprocess.check_output(command, text=True, stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # CalledProcessError: powershell ran but failed.
        # FileNotFoundError: powershell isn't on PATH (e.g. non-Windows host) —
        # previously uncaught, crashing instead of degrading to "no GPUs found".
        return ""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_linux_output():
    """Return `lspci -nn` output, or "" if lspci is missing or exits non-zero."""
    try:
        return subprocess.check_output(["lspci", "-nn"], text=True)
    except (FileNotFoundError, subprocess.CalledProcessError):
        # FileNotFoundError: lspci not installed.
        # CalledProcessError: lspci failed — previously uncaught, crashing detection.
        return ""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_macos_output():
    """Return `system_profiler -json SPDisplaysDataType` output, or "" on failure."""
    try:
        return subprocess.check_output(["system_profiler", "-json", "SPDisplaysDataType"], text=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError: system_profiler absent (non-macOS host) —
        # previously uncaught, crashing instead of degrading to "no GPUs found".
        return ""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def parse_windows_output(output):
    """Extract unique (vendor_id, device_id) pairs from Windows PNP device ID lines."""
    matches = (re.search(r"VEN_(\w+)&DEV_(\w+)", line, re.IGNORECASE) for line in output.splitlines())
    unique_ids = {(m.group(1).lower(), m.group(2).lower()) for m in matches if m}
    return list(unique_ids)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def parse_linux_output(output):
    """Extract unique (vendor_id, device_id) pairs from `lspci -nn` output lines."""
    # Only the first [vendor:device] bracket on each line is considered,
    # matching the original per-line `re.search` behavior.
    matches = (re.search(r"\[(\w+):(\w+)\]", line) for line in output.splitlines())
    unique_ids = {(m.group(1).lower(), m.group(2).lower()) for m in matches if m}
    return list(unique_ids)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def parse_macos_output(output):
    """Extract unique (vendor_id, device_id) pairs from `system_profiler -json SPDisplaysDataType`."""
    try:
        data = json.loads(output)
    except json.JSONDecodeError:
        return []

    pairs = set()
    for display in data.get("SPDisplaysDataType", []):
        vendor_raw = display.get("spdisplays_vendor", "")
        device_raw = display.get("spdisplays_device-id", "")
        if not (vendor_raw and device_raw):
            continue
        if "Intel" in vendor_raw:
            # Apple reports Intel GPUs by name only, without a PCI vendor id.
            vendor_id = "8086"
        else:
            m = re.search(r"\((0x\w+)\)", vendor_raw)
            if m is None:
                continue  # vendor string carries no parenthesized PCI id
            vendor_id = m.group(1).replace("0x", "").lower()
        pairs.add((vendor_id, device_raw.replace("0x", "").lower()))
    return list(pairs)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_pci_ids():
    """Return (vendor_id, device_id) pairs for display devices found on this machine."""
    handlers = {
        "Windows": (get_windows_output, parse_windows_output),
        "Linux": (get_linux_output, parse_linux_output),
        "Darwin": (get_macos_output, parse_macos_output),  # macOS
    }
    if os_name not in handlers:
        return []  # unsupported OS: report no devices
    fetch, parse = handlers[os_name]
    return parse(fetch())
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def get_device_infos(pci_ids):
    """
    Opens the SQLite database bundled next to this module and queries its
    `pci_ids` table for matching vendor_id and device_id. Returns a list of
    tuples containing (vendor_id, vendor_name, device_id, device_name).

    Args:
        pci_ids (list of tuples): List of (vendor_id, device_id) pairs to match.

    Returns:
        list of tuples: List of (vendor_id, vendor_name, device_id, device_name).
        Pairs with no match in the database contribute nothing to the result.
    """
    result = []

    # Establish connection to the database (the db file ships inside the package,
    # alongside this module)
    db_path = os.path.join(os.path.dirname(__file__), DEVICE_DB_FILE)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    try:
        # Create a query to retrieve matching rows
        query = """
            SELECT vendor_id, vendor_name, device_id, device_name
            FROM pci_ids
            WHERE vendor_id = ? AND device_id = ?
        """

        # Execute query for each (vendor_id, device_id) in pci_ids
        for vendor_id, device_id in pci_ids:
            cursor.execute(query, (vendor_id, device_id))
            rows = cursor.fetchall()
            result.extend(rows)

    finally:
        # Close the database connection
        conn.close()

    return result
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def get_discrete_gpus():
    """Return device-info tuples for every discrete GPU detected on this machine."""
    return get_device_infos(get_pci_ids())
|
|
Binary file
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import sys
|
|
3
|
+
import platform
|
|
4
|
+
import subprocess
|
|
5
|
+
|
|
6
|
+
from .consts import CONTACT_LINK
|
|
7
|
+
from .device_db import get_discrete_gpus
|
|
8
|
+
from .platform_detection import get_torch_platform
|
|
9
|
+
|
|
10
|
+
os_name = platform.system()
|
|
11
|
+
|
|
12
|
+
# Run pip through the current interpreter, so packages land in this environment.
PIP_PREFIX = [sys.executable, "-m", "pip", "install"]
# Matches CUDA platform tags like "cu112" or "cu124".
# FIX: was r"^cu[\d+]" — a character class matching ONE digit OR a literal '+',
# not the intended "cu" followed by one-or-more digits.
CUDA_REGEX = re.compile(r"^cu\d+")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_install_commands(torch_platform, packages):
    """Build the pip install argument list(s) (minus the `pip install` prefix) for a platform.

    Args:
        torch_platform (str): one of "cpu", "cuXXX" (e.g. "cu126"), "rocmX.Y"
            (e.g. "rocm6.2"), "xpu", "directml", "ipex".
        packages (list of str): package specs in pip format, e.g. ["torch>=2.0"].
            An empty list means torch, torchaudio and torchvision.

    Returns:
        list of list of str: one sublist per pip invocation, e.g.
        [["torch", "--index-url", "https://foo.com/whl"]] or
        [["torch-directml"], ["torch", "torchvision"]].

    Raises:
        ValueError: for an unrecognized platform string.

    Notes:
        - "xpu" on Windows switches to nightly builds if torchvision/torchaudio are requested.
        - "directml" prepends a "torch-directml" install; "ipex" appends "intel-extension-for-pytorch".
    """
    if not packages:
        packages = ["torch", "torchaudio", "torchvision"]

    if torch_platform == "cpu":
        return [packages]

    if torch_platform.startswith("rocm") or CUDA_REGEX.match(torch_platform):
        return [packages + ["--index-url", f"https://download.pytorch.org/whl/{torch_platform}"]]

    if torch_platform == "xpu":
        channel = "test"
        if os_name == "Windows" and not {"torchvision", "torchaudio"}.isdisjoint(packages):
            print(
                f"[WARNING] The preview build of 'xpu' on Windows currently only supports torch, not torchvision/torchaudio. "
                f"torch-installer will instead use the nightly build, to get the 'xpu' version of torchaudio and torchvision as well. "
                f"Please contact torch-installer if this is no longer accurate: {CONTACT_LINK}"
            )
            channel = "nightly"
        return [packages + ["--index-url", f"https://download.pytorch.org/whl/{channel}/{torch_platform}"]]

    if torch_platform == "directml":
        return [["torch-directml"], packages]

    if torch_platform == "ipex":
        return [packages, ["intel-extension-for-pytorch"]]

    raise ValueError(f"Unsupported platform: {torch_platform}")
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def get_pip_commands(cmds):
    """Prefix every command in `cmds` with `<python> -m pip install`."""
    assert all(cmd is not None for cmd in cmds)
    return [PIP_PREFIX + cmd for cmd in cmds]
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def run_commands(cmds):
    """Echo and execute each command (an argv list) in sequence."""
    for command in cmds:
        print("> ", command)  # show the full argv before running it
        subprocess.run(command)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def install(packages=None):
    """Install the torch packages appropriate for this machine's GPU.

    Args:
        packages: a list of strings with package names (and optionally their versions
            in pip-format). e.g. ["torch", "torchvision"] or ["torch>=2.0", "torchaudio==0.16.0"].
            Defaults to ["torch", "torchvision", "torchaudio"].
    """
    # FIX: the default was previously a mutable `[]` (shared across calls);
    # `None` avoids that pitfall and behaves identically downstream — both are
    # falsy, and get_install_commands substitutes the default package set.

    discrete_gpu_infos = get_discrete_gpus()
    torch_platform = get_torch_platform(discrete_gpu_infos)
    cmds = get_install_commands(torch_platform, packages)
    cmds = get_pip_commands(cmds)
    run_commands(cmds)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import platform
|
|
3
|
+
|
|
4
|
+
from .consts import AMD, INTEL, NVIDIA, CONTACT_LINK
|
|
5
|
+
|
|
6
|
+
os_name = platform.system()
|
|
7
|
+
arch = platform.machine().lower()
|
|
8
|
+
py_version = sys.version_info
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_torch_platform(discrete_gpu_infos):
    """
    Determine the appropriate PyTorch platform to use based on the system architecture, OS, and GPU information.

    Args:
        discrete_gpu_infos (list of tuples): A list where each tuple represents a GPU:
            (vendor_id, vendor_name, device_id, device_name).

    Returns:
        str: A string representing the platform to use. Possible values:
            - "cpu": No discrete GPUs or unsupported configuration.
            - "cuXXX": NVIDIA CUDA version (e.g., "cu124").
            - "rocmXXX": AMD ROCm version (e.g., "rocm6.2").
            - "directml": DirectML for AMD or Intel GPUs on Windows.
            - "ipex": Intel Extension for PyTorch for Linux.
            - "xpu": Intel's backend for PyTorch.

    Raises:
        NotImplementedError: For unsupported OSes/architectures, OS-GPU combinations,
            or multiple GPU vendors on the same machine.
    """

    VALID_ARCHS = {
        "Windows": {"amd64"},
        "Linux": {"x86_64", "aarch64"},
        "Darwin": {"x86_64", "arm64"},
    }

    # FIX: `.get(os_name, set())` — an OS outside this table now raises the
    # friendly NotImplementedError below instead of a bare KeyError.
    if arch not in VALID_ARCHS.get(os_name, set()):
        raise NotImplementedError(
            f"torch is not currently available for {os_name} on {arch} architecture! If this is no longer true, please contact torch-installer at {CONTACT_LINK}"
        )

    if len(discrete_gpu_infos) == 0:
        return "cpu"

    vendor_ids = set(vendor_id for vendor_id, *_ in discrete_gpu_infos)

    if len(vendor_ids) > 1:
        device_names = list(vendor_name + " " + device_name for _, vendor_name, _, device_name in discrete_gpu_infos)
        raise NotImplementedError(
            f"torch-installer does not currently support multiple graphics card manufacturers on the same computer: {device_names}! Please contact torch-installer at {CONTACT_LINK} with details about your hardware."
        )

    vendor_id = vendor_ids.pop()
    if vendor_id == AMD:
        if os_name == "Windows":
            return "directml"
        elif os_name == "Linux":
            # Pick the newest ROCm release that still supports the GPU family.
            device_names = set(device_name for *_, device_name in discrete_gpu_infos)
            if any(
                device_name.startswith("Navi 3") or device_name.startswith("Navi 2") for device_name in device_names
            ):
                if py_version < (3, 9):
                    print(
                        "[WARNING] Support for Python 3.8 was dropped in ROCm 6.2. torch-installer will default to using ROCm 6.1 instead, but consider switching to a newer Python version to use the latest ROCm!"
                    )
                    return "rocm6.1"
                return "rocm6.2"
            elif any(device_name.startswith("Navi 1") for device_name in device_names):
                return "rocm5.2"
            elif any(
                device_name.startswith("Vega 1") or device_name.startswith("Vega 2") for device_name in device_names
            ):
                return "rocm5.7"
            elif any(device_name.startswith("Ellesmere") for device_name in device_names):
                return "rocm4.2"

            print(
                f"[WARNING] Unsupported AMD graphics card: {device_names}. If this is a recent graphics card (less than 8 years old), please contact torch-installer at {CONTACT_LINK} with details about your hardware."
            )
            return "cpu"
        elif os_name == "Darwin":
            return "rocm4.2"
    elif vendor_id == NVIDIA:
        if os_name in ("Windows", "Linux"):
            return "cu124"
        elif os_name == "Darwin":
            raise NotImplementedError(
                f"torch-installer does not currently support NVIDIA graphics cards on Macs! Please contact torch-installer at {CONTACT_LINK}"
            )
    elif vendor_id == INTEL:
        if os_name == "Windows":
            if py_version < (3, 9):
                print(
                    "[WARNING] Support for Python 3.8 was dropped in torch 2.5, which supports a higher-performance 'xpu' backend for Intel. torch-installer will default to using 'directml' instead, but consider switching to a newer Python version to use the latest 'xpu' backend for Intel!"
                )
                return "directml"
            return "xpu"
        elif os_name == "Linux":
            if py_version < (3, 9):
                print(
                    "[WARNING] Support for Python 3.8 was dropped in torch 2.5, which supports a higher-performance 'xpu' backend for Intel. torch-installer will default to using 'intel-extension-for-pytorch' instead, but consider switching to a newer Python version to use the latest 'xpu' backend for Intel!"
                )
                return "ipex"
            return "xpu"
        else:
            raise NotImplementedError(
                f"torch-installer does not currently support Intel graphics cards on Macs! Please contact torch-installer at {CONTACT_LINK}"
            )

    # Unknown vendor id: fall back to CPU rather than failing the install.
    print(f"Unrecognized vendor: {discrete_gpu_infos}")

    return "cpu"
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pytest
|
|
3
|
+
from torchruntime.configure import set_rocm_env_vars
|
|
4
|
+
from torchruntime.consts import AMD
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def create_gpu_info(device_id, device_name):
    """Build an AMD GPU info tuple in torchruntime's (vendor_id, vendor_name, device_id, device_name) shape."""
    return (AMD, "Advanced Micro Devices, Inc. [AMD/ATI]", device_id, device_name)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@pytest.fixture(autouse=True)
def clean_env():
    """Clear ROCm-related env vars before each test and restore them afterwards."""
    # Remove relevant environment variables before each test
    env_vars = [
        "HSA_OVERRIDE_GFX_VERSION",
        "HIP_VISIBLE_DEVICES",
        "ROC_ENABLE_PRE_VEGA",
        "HSA_ENABLE_SDMA",
        "FORCE_FULL_PRECISION",
    ]

    # Store original values
    original_values = {}
    for var in env_vars:
        if var in os.environ:
            original_values[var] = os.environ[var]
            del os.environ[var]

    yield

    # Restore original values and remove any new ones
    for var in env_vars:
        if var in os.environ and var not in original_values:
            del os.environ[var]
        elif var in original_values:
            os.environ[var] = original_values[var]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def test_navi_3_settings():
    """Navi 3x GPUs get the gfx 11.0.0 override and nothing else."""
    gpus = [create_gpu_info("123", "Navi 31 XTX")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "11.0.0"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ
    assert "FORCE_FULL_PRECISION" not in os.environ


def test_navi_2_settings():
    """Navi 2x GPUs get the gfx 10.3.0 override."""
    gpus = [create_gpu_info("123", "Navi 21 XTX")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "10.3.0"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ
    assert "FORCE_FULL_PRECISION" not in os.environ


def test_navi_1_settings():
    """Navi 1x GPUs additionally force full precision."""
    gpus = [create_gpu_info("123", "Navi 14")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "10.3.0"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
    assert os.environ.get("FORCE_FULL_PRECISION") == "yes"
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ


def test_vega_2_settings():
    """Vega 20 (Radeon VII class) GPUs get the gfx 9.0.6 override."""
    gpus = [create_gpu_info("123", "Vega 20 Radeon VII")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "9.0.6"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
    assert "FORCE_FULL_PRECISION" not in os.environ
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ


def test_vega_1_settings():
    """Vega 10 GPUs get the gfx 9.0.0 override plus forced full precision."""
    gpus = [create_gpu_info("123", "Vega 10")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "9.0.0"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
    assert os.environ.get("FORCE_FULL_PRECISION") == "yes"
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ


def test_ellesmere_settings():
    """Ellesmere (Polaris) GPUs get gfx 8.0.3 and the pre-Vega enable flag."""
    gpus = [create_gpu_info("123", "Ellesmere RX 580")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "8.0.3"
    assert os.environ.get("ROC_ENABLE_PRE_VEGA") == "1"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
    assert "FORCE_FULL_PRECISION" not in os.environ


def test_unknown_gpu_settings():
    """Unknown GPU families leave the environment completely untouched."""
    gpus = [create_gpu_info("123", "Unknown GPU")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert "ROC_ENABLE_PRE_VEGA" not in os.environ
    assert "HIP_VISIBLE_DEVICES" not in os.environ
    assert "HSA_OVERRIDE_GFX_VERSION" not in os.environ
    assert "FORCE_FULL_PRECISION" not in os.environ


def test_multiple_gpus_same_type():
    """Multiple GPUs of one family are all exposed via HIP_VISIBLE_DEVICES."""
    gpus = [create_gpu_info("123", "Navi 31 XTX"), create_gpu_info("124", "Navi 31 XT")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "11.0.0"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0,1"
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ
    assert "FORCE_FULL_PRECISION" not in os.environ


def test_multiple_gpus_mixed_types():
    """Mixed AMD families: the newest matching family's settings win."""
    gpus = [create_gpu_info("123", "Navi 31 XTX"), create_gpu_info("124", "Navi 21 XT")]
    set_rocm_env_vars(gpus, "rocm6.2")

    # Should use Navi 3 settings since at least one GPU is Navi 3
    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "11.0.0"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0,1"
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ
    assert "FORCE_FULL_PRECISION" not in os.environ


def test_empty_gpu_list():
    """An empty GPU list sets nothing (falls into the unrecognized branch)."""
    gpus = []
    set_rocm_env_vars(gpus, "rocm6.2")

    assert "ROC_ENABLE_PRE_VEGA" not in os.environ
    assert "HIP_VISIBLE_DEVICES" not in os.environ
    assert "HSA_OVERRIDE_GFX_VERSION" not in os.environ
    assert "FORCE_FULL_PRECISION" not in os.environ
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from torchruntime.device_db import parse_windows_output, parse_linux_output, parse_macos_output
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def test_parse_windows_output():
    """Windows PNP strings yield lowercase (vendor, device) id pairs."""
    output = """
    PCI\\VEN_8086&DEV_591B&SUBSYS_2212103C&REV_04
    PCI\\VEN_10DE&DEV_1C82&SUBSYS_37131462&REV_A1
    PCI\\VEN_10DE&DEV_2504&SUBSYS_881D1043&REV_A1\\4&22AF55FA&0&0008
    """
    expected = [("8086", "591b"), ("10de", "1c82"), ("10de", ("2504"))]
    assert sorted(parse_windows_output(output)) == sorted(expected)


def test_parse_linux_output():
    """`lspci -nn` lines yield the bracketed [vendor:device] id pairs."""
    output = """
    00:02.0 VGA compatible controller: Intel Corporation HD Graphics 620 (rev 02) [8086:5916]
    01:00.0 3D controller: NVIDIA Corporation GP108M [GeForce MX150] (rev a1) [10de:1d10]
    """
    expected = [("8086", "5916"), ("10de", "1d10")]
    assert sorted(parse_linux_output(output)) == sorted(expected)


def test_parse_macos_output():
    """system_profiler JSON yields id pairs; Intel gets the fixed 8086 vendor id."""
    output = """
    {
        "SPDisplaysDataType": [
            {
                "spdisplays_vendor": "Intel",
                "spdisplays_device-id": "0x5916"
            },
            {
                "spdisplays_vendor": "NVIDIA (0x10de)",
                "spdisplays_device-id": "0x1c82"
            }
        ]
    }
    """
    expected = [("8086", "5916"), ("10de", "1c82")]
    assert sorted(parse_macos_output(output)) == sorted(expected)


def test_parse_macos_output_invalid_json():
    """Malformed JSON degrades to an empty result rather than raising."""
    output = "{ invalid_json: true }"
    assert parse_macos_output(output) == []


def test_empty_outputs():
    """All parsers return an empty list for empty input."""
    assert parse_windows_output("") == []
    assert parse_linux_output("") == []
    assert parse_macos_output("") == []
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# Integration test, which connects to the database and checks for some common devices
|
|
2
|
+
|
|
3
|
+
from torchruntime.device_db import get_device_infos, DEVICE_DB_FILE
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def test_db_file_exists():
    """The bundled SQLite database must ship next to the device_db module."""
    import os
    from torchruntime import device_db

    db_path = os.path.join(os.path.dirname(device_db.__file__), DEVICE_DB_FILE)
    assert os.path.exists(db_path)


def test_get_single_device():
    """Test retrieving a single device."""
    result = get_device_infos([("8086", "56a7")])
    assert len(result) == 1
    assert result[0] == ("8086", "Intel Corporation", "56a7", "DG2 [Arc Xe Graphics]")


def test_get_multiple_devices():
    """Test retrieving multiple devices."""
    input_ids = [("10de", "2786"), ("10de", "2504")]
    result = get_device_infos(input_ids)
    assert len(result) == 2
    assert ("10de", "NVIDIA Corporation", "2786", "AD104 [GeForce RTX 4070]") in result
    assert ("10de", "NVIDIA Corporation", "2504", "GA106 [GeForce RTX 3060 Lite Hash Rate]") in result


def test_get_amd_devices():
    """Test retrieving AMD devices."""
    input_ids = [("1002", "9495"), ("1002", "747e")]
    result = get_device_infos(input_ids)
    assert len(result) == 2
    assert ("1002", "Advanced Micro Devices, Inc. [AMD/ATI]", "9495", "RV730 [Radeon HD 4600 AGP Series]") in result
    assert ("1002", "Advanced Micro Devices, Inc. [AMD/ATI]", "747e", "Navi 32 [Radeon RX 7700 XT / 7800 XT]") in result


def test_get_nonexistent_device():
    """Test retrieving a device that doesn't exist in the database."""
    result = get_device_infos([("ffff", "ffff")])
    assert len(result) == 0


def test_get_mixed_existing_and_nonexistent():
    """Test retrieving a mix of existing and non-existing devices."""
    input_ids = [("8086", "56a7"), ("ffff", "ffff")]  # exists  # doesn't exist
    result = get_device_infos(input_ids)
    assert len(result) == 1
    assert result[0] == ("8086", "Intel Corporation", "56a7", "DG2 [Arc Xe Graphics]")
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import sys
import pytest
from unittest.mock import patch
from torchruntime.installer import get_install_commands, get_pip_commands, run_commands


def test_empty_args():
    """With no packages specified, the default trio of torch packages is installed."""
    cmds = get_install_commands("cpu", [])
    assert cmds == [["torch", "torchaudio", "torchvision"]]


def test_cpu_platform():
    """CPU installs come straight from PyPI, with no extra index URL."""
    pkgs = ["torch", "torchvision"]
    assert get_install_commands("cpu", pkgs) == [pkgs]


def test_cuda_platform():
    """CUDA platforms point pip at the matching download.pytorch.org index."""
    pkgs = ["torch", "torchvision"]
    cmds = get_install_commands("cu112", pkgs)
    assert cmds == [pkgs + ["--index-url", "https://download.pytorch.org/whl/cu112"]]


def test_rocm_platform():
    """ROCm platforms point pip at the matching download.pytorch.org index."""
    pkgs = ["torch", "torchvision"]
    cmds = get_install_commands("rocm4.2", pkgs)
    assert cmds == [pkgs + ["--index-url", "https://download.pytorch.org/whl/rocm4.2"]]


def test_xpu_platform_windows_with_torch_only(monkeypatch):
    """On Windows, a torch-only xpu install uses the 'test' channel."""
    monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
    pkgs = ["torch"]
    cmds = get_install_commands("xpu", pkgs)
    assert cmds == [pkgs + ["--index-url", "https://download.pytorch.org/whl/test/xpu"]]


def test_xpu_platform_windows_with_torchvision(monkeypatch, capsys):
    """Requesting torchvision on Windows xpu falls back to the nightly channel and warns."""
    monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
    pkgs = ["torch", "torchvision"]
    cmds = get_install_commands("xpu", pkgs)
    assert cmds == [pkgs + ["--index-url", "https://download.pytorch.org/whl/nightly/xpu"]]
    assert "[WARNING]" in capsys.readouterr().out


def test_xpu_platform_linux(monkeypatch):
    """On Linux, xpu installs use the 'test' channel for all packages."""
    monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
    pkgs = ["torch", "torchvision"]
    cmds = get_install_commands("xpu", pkgs)
    assert cmds == [pkgs + ["--index-url", "https://download.pytorch.org/whl/test/xpu"]]


def test_directml_platform():
    """DirectML installs torch-directml first, then the requested packages."""
    pkgs = ["torch", "torchvision"]
    assert get_install_commands("directml", pkgs) == [["torch-directml"], pkgs]


def test_ipex_platform():
    """ipex installs the requested packages first, then the Intel extension."""
    pkgs = ["torch", "torchvision"]
    assert get_install_commands("ipex", pkgs) == [pkgs, ["intel-extension-for-pytorch"]]


def test_unsupported_platform():
    """Unknown platform strings are rejected with a ValueError."""
    with pytest.raises(ValueError, match="Unsupported platform: unknown"):
        get_install_commands("unknown", ["torch", "torchvision"])


def test_get_pip_commands_valid():
    """Each command list gets prefixed with the pip invocation for this interpreter."""
    pip_prefix = [sys.executable, "-m", "pip", "install"]
    cmds = [["package1"], ["package2", "--upgrade"]]

    assert get_pip_commands(cmds) == [pip_prefix + cmd for cmd in cmds]


def test_get_pip_commands_none_input():
    """A None entry in the command list trips the internal assertion."""
    with pytest.raises(AssertionError):
        get_pip_commands([["package1"], None])


# Test suite for run_commands
def test_run_commands():
    """run_commands invokes subprocess.run once per command, passing each verbatim."""
    cmds = [
        [sys.executable, "-m", "pip", "install", "package1"],
        [sys.executable, "-m", "pip", "install", "package2", "--upgrade"],
    ]

    with patch("subprocess.run") as mock_run:
        run_commands(cmds)

        # One subprocess.run call per command...
        assert mock_run.call_count == len(cmds)

        # ...and each was passed through unchanged
        for cmd in cmds:
            mock_run.assert_any_call(cmd)
@@ -0,0 +1,164 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from torchruntime.platform_detection import get_torch_platform, AMD, NVIDIA, INTEL, os_name, arch, py_version
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def test_no_discrete_gpus_windows(monkeypatch):
|
|
6
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
|
|
7
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
|
|
8
|
+
assert get_torch_platform([]) == "cpu"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def test_no_discrete_gpus_linux(monkeypatch):
|
|
12
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
|
|
13
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
|
|
14
|
+
assert get_torch_platform([]) == "cpu"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def test_no_discrete_gpus_mac(monkeypatch):
|
|
18
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Darwin")
|
|
19
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "arm64")
|
|
20
|
+
assert get_torch_platform([]) == "cpu"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_amd_gpu_windows(monkeypatch):
|
|
24
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
|
|
25
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
|
|
26
|
+
discrete_gpu_infos = [(AMD, "AMD", 0x1234, "Radeon")]
|
|
27
|
+
assert get_torch_platform(discrete_gpu_infos) == "directml"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_amd_gpu_navi3_linux(monkeypatch, capsys):
|
|
31
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
|
|
32
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
|
|
33
|
+
discrete_gpu_infos = [(AMD, "AMD", 0x1234, "Navi 31")]
|
|
34
|
+
expected = "rocm6.1" if py_version < (3, 9) else "rocm6.2"
|
|
35
|
+
assert get_torch_platform(discrete_gpu_infos) == expected
|
|
36
|
+
if py_version < (3, 9):
|
|
37
|
+
captured = capsys.readouterr()
|
|
38
|
+
assert "Support for Python 3.8 was dropped in ROCm 6.2" in captured.out
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def test_amd_gpu_navi2_linux(monkeypatch, capsys):
|
|
42
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
|
|
43
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
|
|
44
|
+
discrete_gpu_infos = [(AMD, "AMD", 0x1234, "Navi 22")]
|
|
45
|
+
expected = "rocm6.1" if py_version < (3, 9) else "rocm6.2"
|
|
46
|
+
assert get_torch_platform(discrete_gpu_infos) == expected
|
|
47
|
+
if py_version < (3, 9):
|
|
48
|
+
captured = capsys.readouterr()
|
|
49
|
+
assert "Support for Python 3.8 was dropped in ROCm 6.2" in captured.out
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def test_amd_gpu_navi1_linux(monkeypatch):
|
|
53
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
|
|
54
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
|
|
55
|
+
discrete_gpu_infos = [(AMD, "AMD", 0x1234, "Navi 10")]
|
|
56
|
+
assert get_torch_platform(discrete_gpu_infos) == "rocm5.2"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def test_amd_gpu_vega1_linux(monkeypatch):
|
|
60
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
|
|
61
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
|
|
62
|
+
discrete_gpu_infos = [(AMD, "AMD", 0x1234, "Vega 10")]
|
|
63
|
+
assert get_torch_platform(discrete_gpu_infos) == "rocm5.7"
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def test_amd_gpu_vega2_linux(monkeypatch):
|
|
67
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
|
|
68
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
|
|
69
|
+
discrete_gpu_infos = [(AMD, "AMD", 0x1234, "Vega 20")]
|
|
70
|
+
assert get_torch_platform(discrete_gpu_infos) == "rocm5.7"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def test_amd_gpu_ellesmere_linux(monkeypatch):
|
|
74
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
|
|
75
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
|
|
76
|
+
discrete_gpu_infos = [(AMD, "AMD", 0x1234, "Ellesmere")]
|
|
77
|
+
assert get_torch_platform(discrete_gpu_infos) == "rocm4.2"
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def test_amd_gpu_unsupported_linux(monkeypatch, capsys):
|
|
81
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
|
|
82
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
|
|
83
|
+
discrete_gpu_infos = [(AMD, "AMD", 0x1234, "UnknownModel")]
|
|
84
|
+
assert get_torch_platform(discrete_gpu_infos) == "cpu"
|
|
85
|
+
captured = capsys.readouterr()
|
|
86
|
+
assert "[WARNING] Unsupported AMD graphics card" in captured.out
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def test_amd_gpu_mac(monkeypatch):
|
|
90
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Darwin")
|
|
91
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "arm64")
|
|
92
|
+
discrete_gpu_infos = [(AMD, "AMD", 0x1234, "Radeon")]
|
|
93
|
+
assert get_torch_platform(discrete_gpu_infos) == "rocm4.2"
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def test_nvidia_gpu_windows(monkeypatch):
|
|
97
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
|
|
98
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
|
|
99
|
+
discrete_gpu_infos = [(NVIDIA, "NVIDIA", 0x1234, "GeForce")]
|
|
100
|
+
assert get_torch_platform(discrete_gpu_infos) == "cu124"
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def test_nvidia_gpu_linux(monkeypatch):
|
|
104
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
|
|
105
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
|
|
106
|
+
discrete_gpu_infos = [(NVIDIA, "NVIDIA", 0x1234, "GeForce")]
|
|
107
|
+
assert get_torch_platform(discrete_gpu_infos) == "cu124"
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def test_nvidia_gpu_mac(monkeypatch):
|
|
111
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Darwin")
|
|
112
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "arm64")
|
|
113
|
+
discrete_gpu_infos = [(NVIDIA, "NVIDIA", 0x1234, "GeForce")]
|
|
114
|
+
with pytest.raises(NotImplementedError):
|
|
115
|
+
get_torch_platform(discrete_gpu_infos)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def test_intel_gpu_windows(monkeypatch):
|
|
119
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
|
|
120
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
|
|
121
|
+
discrete_gpu_infos = [(INTEL, "Intel", 0x1234, "Iris")]
|
|
122
|
+
expected = "directml" if py_version < (3, 9) else "xpu"
|
|
123
|
+
assert get_torch_platform(discrete_gpu_infos) == expected
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def test_intel_gpu_linux(monkeypatch):
|
|
127
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
|
|
128
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
|
|
129
|
+
discrete_gpu_infos = [(INTEL, "Intel", 0x1234, "Iris")]
|
|
130
|
+
expected = "ipex" if py_version < (3, 9) else "xpu"
|
|
131
|
+
assert get_torch_platform(discrete_gpu_infos) == expected
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def test_intel_gpu_mac(monkeypatch):
|
|
135
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Darwin")
|
|
136
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "arm64")
|
|
137
|
+
discrete_gpu_infos = [(INTEL, "Intel", 0x1234, "Iris")]
|
|
138
|
+
with pytest.raises(NotImplementedError):
|
|
139
|
+
get_torch_platform(discrete_gpu_infos)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def test_multiple_gpu_vendors(monkeypatch):
|
|
143
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
|
|
144
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
|
|
145
|
+
discrete_gpu_infos = [
|
|
146
|
+
(AMD, "AMD", 0x1234, "Radeon"),
|
|
147
|
+
(NVIDIA, "NVIDIA", 0x5678, "GeForce"),
|
|
148
|
+
]
|
|
149
|
+
with pytest.raises(NotImplementedError):
|
|
150
|
+
get_torch_platform(discrete_gpu_infos)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def test_unsupported_architecture(monkeypatch):
|
|
154
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
|
|
155
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "sparc")
|
|
156
|
+
with pytest.raises(NotImplementedError):
|
|
157
|
+
get_torch_platform([])
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def test_unrecognized_gpu_vendor(monkeypatch):
|
|
161
|
+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
|
|
162
|
+
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
|
|
163
|
+
discrete_gpu_infos = [("9999", "UnknownVendor", 0x1234, "Unknown")]
|
|
164
|
+
assert get_torch_platform(discrete_gpu_infos) == "cpu"
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 easydiffusion
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: torchruntime
|
|
3
|
+
Version: 1.0.2
|
|
4
|
+
Summary: Meant for app developers. A convenient way to install and configure the appropriate version of PyTorch on the user's computer, based on the OS and GPU manufacturer and model number.
|
|
5
|
+
Author-email: cmdr2 <secondary.cmdr2@gmail.com>
|
|
6
|
+
Project-URL: Homepage, https://github.com/easydiffusion/torchruntime
|
|
7
|
+
Project-URL: Bug Tracker, https://github.com/easydiffusion/torchruntime/issues
|
|
8
|
+
Keywords: torch,ai,ml,llm,installer,runtime
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
11
|
+
Classifier: Operating System :: Microsoft :: Windows :: Windows 10
|
|
12
|
+
Classifier: Operating System :: Microsoft :: Windows :: Windows 11
|
|
13
|
+
Classifier: Operating System :: POSIX :: Linux
|
|
14
|
+
Classifier: Operating System :: MacOS
|
|
15
|
+
Requires-Python: >=3.0
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
License-File: LICENSE
|
|
18
|
+
|
|
19
|
+
# torchruntime
|
|
20
|
+
[](https://discord.com/invite/u9yhsFmEkB)
|
|
21
|
+
|
|
22
|
+
**torchruntime** is a lightweight package for automatically installing the appropriate variant of PyTorch on a user's computer, based on their OS, and GPU manufacturer and GPU model.
|
|
23
|
+
|
|
24
|
+
This package is used by [Easy Diffusion](https://github.com/easydiffusion/easydiffusion), but you're welcome to use it as well. It's useful for developers who make PyTorch-based apps that target users with NVIDIA, AMD and Intel graphics cards (as well as CPU-only usage), on Windows, Mac and Linux.
|
|
25
|
+
|
|
26
|
+
### Why?
|
|
27
|
+
It lets you treat PyTorch as a single dependency (like it should be), and lets you assume that each user will get the most-performant variant of PyTorch suitable for their computer's OS and hardware.
|
|
28
|
+
|
|
29
|
+
It deals with the complexity of the variety of torch builds and configurations required for CUDA, AMD (ROCm, DirectML), Intel (xpu/DirectML/ipex), and CPU-only.
|
|
30
|
+
|
|
31
|
+
**Compatibility table**: [Click here](#compatibility-table) to see the supported graphics cards and operating systems.
|
|
32
|
+
|
|
33
|
+
# Installation
|
|
34
|
+
Supports Windows, Linux, and Mac.
|
|
35
|
+
|
|
36
|
+
`pip install torchruntime`
|
|
37
|
+
|
|
38
|
+
## Usage
|
|
39
|
+
### Step 1. Install the appropriate variant of PyTorch
|
|
40
|
+
*This command should be run on the user's computer, or while creating platform-specific builds:*
|
|
41
|
+
|
|
42
|
+
`python -m torchruntime install`
|
|
43
|
+
|
|
44
|
+
This will install `torch`, `torchvision`, and `torchaudio`, and will decide the variant based on the user's OS, GPU manufacturer and GPU model number. See [customizing packages](#customizing-packages) for more options.
|
|
45
|
+
|
|
46
|
+
### Step 2. Initialize torch
|
|
47
|
+
This should be run inside your program, to initialize the required environment variables (if any) for the variant of torch being used.
|
|
48
|
+
|
|
49
|
+
```py
|
|
50
|
+
import torchruntime
|
|
51
|
+
|
|
52
|
+
torchruntime.init_torch()
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
## Customizing packages
|
|
56
|
+
By default, `python -m torchruntime install` will install the latest available `torch`, `torchvision` and `torchaudio` suitable on the user's platform.
|
|
57
|
+
|
|
58
|
+
You can customize the packages to install by including their names:
|
|
59
|
+
* For e.g. to install only `torch` and `torchvision`, you can run `python -m torchruntime install torch torchvision`
|
|
60
|
+
* To install specific versions (in pip format), you can run `python -m torchruntime install "torch>2.0" "torchvision==0.20"`
|
|
61
|
+
|
|
62
|
+
**Note:** If you specify package versions, please keep in mind that the version may not be available to *all* the users on *all* the torch platforms. For e.g. a user with Python 3.8 would not be able to install torch 2.5 (or higher), because torch 2.5 dropped support for Python 3.8.
|
|
63
|
+
|
|
64
|
+
So in general, it's better to avoid specifying a version unless it really matters to you (or you know what you're doing). Instead, please allow `torchruntime` to pick the latest-possible version for the user.
|
|
65
|
+
|
|
66
|
+
# Compatibility table
|
|
67
|
+
The list of platforms on which `torchruntime` can install a working variant of PyTorch.
|
|
68
|
+
|
|
69
|
+
**Note:** *This list is based on user feedback (since I don't have all the cards). Please let me know if your card is supported (or not) by opening a pull request or issue or messaging on [Discord](https://discord.com/invite/u9yhsFmEkB) (with supporting logs).*
|
|
70
|
+
|
|
71
|
+
**CPU-only:**
|
|
72
|
+
|
|
73
|
+
| OS | Supported?| Notes |
|
|
74
|
+
|---|---|---|
|
|
75
|
+
| Windows | ✅ Yes | x86_64 |
|
|
76
|
+
| Linux | ✅ Yes | x86_64 and aarch64 |
|
|
77
|
+
| Mac (M1/M2/M3/M4) | ✅ Yes | arm64. `mps` backend |
|
|
78
|
+
| Mac (Intel) | ✅ Yes | x86_64. Stopped after `torch 2.2.2` |
|
|
79
|
+
|
|
80
|
+
**NVIDIA:**
|
|
81
|
+
|
|
82
|
+
| Series | Supported? | OS | Notes |
|
|
83
|
+
|---|---|---|---|
|
|
84
|
+
| 40xx | ✅ Yes | Win/Linux | Uses CUDA 124 |
|
|
85
|
+
| 30xx | ✅ Yes | Win/Linux | Uses CUDA 124 |
|
|
86
|
+
| 20xx | ✅ Yes | Win/Linux | Uses CUDA 124 |
|
|
87
|
+
| 10xx/16xx | ✅ Yes | Win/Linux | Uses CUDA 124. Full-precision required on 16xx series |
|
|
88
|
+
|
|
89
|
+
**AMD:**
|
|
90
|
+
|
|
91
|
+
| Series | Supported? | OS | Notes |
|
|
92
|
+
|---|---|---|---|
|
|
93
|
+
| 7xxx | ✅ Yes | Win/Linux | Navi3/RDNA3 (gfx110x). ROCm 6.2 on Linux. DirectML on Windows |
|
|
94
|
+
| 6xxx | ✅ Yes | Win/Linux | Navi2/RDNA2 (gfx103x). ROCm 6.2 on Linux. DirectML on Windows |
|
|
95
|
+
| 5xxx | ✅ Yes | Win/Linux | Navi1/RDNA1 (gfx101x). Full-precision required. DirectML on Windows. Linux only supports up to ROCm 5.2. Waiting for [this](https://github.com/pytorch/pytorch/issues/132570#issuecomment-2313071756) for ROCm 6.2 support. |
|
|
96
|
+
| 5xxx on Intel Mac | ❓ Untested (WIP) | Intel Mac | gfx101x. Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
|
|
97
|
+
| 4xxxG/Radeon VII | ✅ Yes | Win/Linux | Vega 20 gfx906. Need testers for Windows, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
|
|
98
|
+
| 2xxxG/Radeon RX Vega 56 | ❓ Untested (WIP) | N/A | Vega 10 gfx900. Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
|
|
99
|
+
| 5xx/Polaris | ❓ Untested (WIP) | N/A | gfx80x. Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
|
|
100
|
+
|
|
101
|
+
**Apple:**
|
|
102
|
+
|
|
103
|
+
| Series | Supported? | Notes |
|
|
104
|
+
|---|---|---|
|
|
105
|
+
| M1/M2/M3/M4 | ✅ Yes | 'mps' backend |
|
|
106
|
+
| AMD on Intel Mac | ❓ Untested (WIP) | Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
|
|
107
|
+
|
|
108
|
+
**Intel:**
|
|
109
|
+
|
|
110
|
+
| Series | Supported? | OS | Notes |
|
|
111
|
+
|---|---|---|---|
|
|
112
|
+
| Arc | ❓ Untested (WIP) | Win/Linux | Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB). Backends: 'xpu' or DirectML or [ipex](https://github.com/intel/intel-extension-for-pytorch) |
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# FAQ
|
|
116
|
+
## Why can't I just run 'pip install torch'?
|
|
117
|
+
`pip install torch` installs the CPU-only version of torch, so it won't utilize your GPU's capabilities.
|
|
118
|
+
|
|
119
|
+
## Why can't I just install torch-for-ROCm directly to support AMD?
|
|
120
|
+
Different models of AMD cards require different LLVM targets, and sometimes different ROCm versions. And ROCm currently doesn't work on Windows, so AMD on Windows is best served (currently) with DirectML.
|
|
121
|
+
|
|
122
|
+
And plenty of AMD cards work with ROCm (even when they aren't in the official list of supported cards). Information about these cards (for e.g. the LLVM target to use) is pretty scattered.
|
|
123
|
+
|
|
124
|
+
`torchruntime` deals with this complexity for your convenience.
|
|
125
|
+
|
|
126
|
+
# Contributing
|
|
127
|
+
📢 I'm looking for contributions in these specific areas:
|
|
128
|
+
- More testing on consumer AMD GPUs.
|
|
129
|
+
- More support for older AMD GPUs. Explore: Compile and host PyTorch wheels and ROCm (on GitHub) for older AMD GPUs (e.g. 580/590/Polaris) with the required patches.
|
|
130
|
+
- Intel GPUs.
|
|
131
|
+
- Testing on professional AMD GPUs (e.g. the Instinct series).
|
|
132
|
+
- An easy-to-run benchmark script (that people can run to check the level of compatibility on their platform).
|
|
133
|
+
|
|
134
|
+
Please message on the [Discord community](https://discord.com/invite/u9yhsFmEkB) if you have AMD or Intel GPUs, and would like to help with testing or adding support for them! Thanks!
|
|
135
|
+
|
|
136
|
+
# Credits
|
|
137
|
+
* Code contributors on [Easy Diffusion](https://github.com/easydiffusion/easydiffusion).
|
|
138
|
+
* Users on [Easy Diffusion's Discord](https://discord.com/invite/u9yhsFmEkB) who've helped with testing on various GPUs.
|
|
139
|
+
* [PCI Database](https://github.com/pciutils/pciids/) automatically generated from the PCI ID Database at http://pci-ids.ucw.cz
|
|
140
|
+
|
|
141
|
+
# More resources
|
|
142
|
+
* [AMD GPU LLVM Architectures](https://web.archive.org/web/20241228163540/https://llvm.org/docs/AMDGPUUsage.html#processors)
|
|
143
|
+
* [Status of ROCm support for AMD Navi 1](https://github.com/ROCm/ROCm/issues/2527)
|
|
144
|
+
* [Torch support for ROCm 6.2 on AMD Navi 1](https://github.com/pytorch/pytorch/issues/132570#issuecomment-2313071756)
|
|
145
|
+
* [ROCmLibs-for-gfx1103-AMD780M-APU](https://github.com/likelovewant/ROCmLibs-for-gfx1103-AMD780M-APU)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
torchruntime/__init__.py,sha256=m5-DZXo3-hf9yhDiS8f9TfyQC6ztpE7VHdo1BRQmFE8,65
|
|
2
|
+
torchruntime/__main__.py,sha256=Cf7g3YWVtxqSFSGp6KLeg65JPYOz4QE8atk7gP_SeTM,1449
|
|
3
|
+
torchruntime/configure.py,sha256=MJmxkDBzdyHoGWRmeP9Pgsx4jdySVrTsXk5vnmwLDSA,3983
|
|
4
|
+
torchruntime/consts.py,sha256=-N9FORwsESh8e20xaX3l-Ord05BmNrPFmsdk6iwv0Rg,108
|
|
5
|
+
torchruntime/device_db.py,sha256=VadbEm_kJgNNyhj7Y-LLKdWxZUKFU3-MjDU5xHiBYSo,4154
|
|
6
|
+
torchruntime/gpu_pci_ids.db,sha256=j9LzoFILdTn-149K9MY-SIZk7HUxIXzPd7dl2ybUoZk,221184
|
|
7
|
+
torchruntime/installer.py,sha256=Mfqkv8Eh1XTD5e8ycqtQ3Cjv0Db6NBpBgeZsUciYKJE,3689
|
|
8
|
+
torchruntime/platform_detection.py,sha256=QuBNvS813Dk95KbOiRiaYBcsrkCtI39ftBpIE8n7zsI,5364
|
|
9
|
+
torchruntime/tests/test_configure.py,sha256=emiHac_MnCLq7_V82ActNUgZI37v4ozlENG-jIydGiM,4581
|
|
10
|
+
torchruntime/tests/test_device_db.py,sha256=iAWAbJSJLZS9DCt59SSQVoaKSxoDRnTPrIwCVcKiCfI,1590
|
|
11
|
+
torchruntime/tests/test_device_db_integration.py,sha256=Vwk9z3HG8A8T9j-tWUseWDy_Te9_jjlhco-LOTDI06A,1955
|
|
12
|
+
torchruntime/tests/test_installer.py,sha256=8bwYGmJihdhjfg8YT6QKqjDnO15B2vXAb3DGrrAbkms,3625
|
|
13
|
+
torchruntime/tests/test_platform_detection.py,sha256=-lo7qPuaP7pPKSCxW3a2CkzWE2EmjUESB7jtFsaj8fU,7335
|
|
14
|
+
torchruntime-1.0.2.dist-info/LICENSE,sha256=wrkpT53J7BEDdUpsh9F-0-1j-LNkStyAubQNYet24BA,1070
|
|
15
|
+
torchruntime-1.0.2.dist-info/METADATA,sha256=QS0jG6u_6RW0V48u82y1rEhmMBLO1cuR0zZB92VtSmQ,8533
|
|
16
|
+
torchruntime-1.0.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
|
17
|
+
torchruntime-1.0.2.dist-info/top_level.txt,sha256=8JeRerOboHkQ_Jw_HiZnctS1d4wGHZa9HZJdd0wWRwQ,13
|
|
18
|
+
torchruntime-1.0.2.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
torchruntime
|