torchruntime 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2 @@
1
+ from .installer import install
2
+ from .configure import init_torch
@@ -0,0 +1,52 @@
1
+ from .installer import install
2
+
3
+
4
def print_usage():
    """Print usage information with examples."""
    # Kept as one triple-quoted literal so the help text is easy to edit;
    # .strip() trims the leading/trailing blank lines of the literal.
    usage = """
Usage: python -m torchruntime <command> [arguments]

Commands:
  install    Install PyTorch packages
  --help     Show this help message

Examples:
  python -m torchruntime install
  python -m torchruntime install torch==2.2.0 torchvision==0.17.0
  python -m torchruntime install torch>=2.0.0 torchaudio
  python -m torchruntime install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0

If no packages are specified, the latest available versions
of torch, torchaudio and torchvision will be installed.

Version specification formats (follows pip format):
  package==2.1.0    Exact version
  package>=2.0.0    Minimum version
  package<=2.2.0    Maximum version
  package~=2.1.0    Compatible release
  package==2.1.*    Any 2.1.x version
  package           Latest version
"""
    print(usage.strip())
31
+
32
+
33
def main():
    """CLI entry point: dispatch `python -m torchruntime <command> [args...]`."""
    import sys

    args = sys.argv[1:]

    # No command, or an explicit help flag -> show usage and exit.
    if not args or args[0] in ("--help", "-h"):
        print_usage()
        return

    command, *rest = args

    if command == "install":
        # Forward any remaining arguments (package specs) to install()
        install(rest if rest else None)
    else:
        print(f"Unknown command: {command}")
        print_usage()


if __name__ == "__main__":
    main()
@@ -0,0 +1,80 @@
1
+ import os
2
+
3
+ from .device_db import get_discrete_gpus
4
+ from .platform_detection import get_torch_platform, os_name
5
+
6
+
7
def init_torch():
    """Detect this machine's GPUs and apply the environment tweaks torch needs."""
    gpus = get_discrete_gpus()
    platform_name = get_torch_platform(gpus)

    if platform_name.startswith("rocm"):
        check_rocm_permissions()
        set_rocm_env_vars(gpus, platform_name)
    elif os_name == "Darwin":
        # Let torch fall back to the CPU for ops not implemented on MPS
        _set_env_vars({"PYTORCH_ENABLE_MPS_FALLBACK": "1"})
16
+
17
+
18
def check_rocm_permissions():
    """Warn (without raising) if /dev/kfd is not writable — ROCm needs write access."""
    if os.access("/dev/kfd", os.W_OK):
        return

    print(
        """#########################################################################
# No write access to /dev/kfd ! #
#########################################################################

Without this, the ROCm driver will probably not be able to initialize the GPU and torch will use the CPU for rendering.

Follow the instructions on this site to configure access to /dev/kfd:
https://github.com/easydiffusion/easydiffusion/wiki/AMD-on-Linux#access-permissions"""
    )
30
+
31
+
32
def set_rocm_env_vars(discrete_gpu_infos, torch_platform):
    """Set the ROCm-related environment variables for the detected AMD GPUs.

    Args:
        discrete_gpu_infos (list of tuples): (vendor_id, vendor_name, device_id, device_name) rows.
        torch_platform (str): The torch platform string (e.g. "rocm6.2"). Not read
            currently; kept for future per-ROCm-version overrides.

    If no known GPU family is recognized (or the list is empty), a warning is
    printed and no environment variables are changed.
    """
    device_names = [device_name for *_, device_name in discrete_gpu_infos]

    env = {}

    # interesting reading:
    # gfx config from: https://web.archive.org/web/20241228163540/https://llvm.org/docs/AMDGPUUsage.html#processors
    # more info: https://web.archive.org/web/20241209013717/https://discuss.linuxcontainers.org/t/rocm-and-pytorch-on-amd-apu-or-gpu-ai/19743
    # this thread is great for understanding the status of torch support for RDNA 1 (i.e. 5000 series): https://github.com/ROCm/ROCm/issues/2527
    # past settings from: https://github.com/easydiffusion/easydiffusion/blob/20d77a85a1ed766ece0cc4b6a55dca003bce262c/scripts/check_modules.py#L405-L420

    if any(device_name.startswith("Navi 3") for device_name in device_names):
        print("[INFO] Applying Navi 3x settings")
        env["HSA_OVERRIDE_GFX_VERSION"] = "11.0.0"
    elif any(device_name.startswith("Navi 2") for device_name in device_names):
        print("[INFO] Applying Navi 2x settings")
        env["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"
    elif any(device_name.startswith("Navi 1") for device_name in device_names):
        print("[INFO] Applying Navi 1x settings")
        env["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"
        # env["HSA_ENABLE_SDMA"] = "0"  # uncomment this if facing errors like in https://github.com/ROCm/ROCm/issues/2616
        env["FORCE_FULL_PRECISION"] = "yes"  # https://github.com/ROCm/ROCm/issues/2527
        # FORCE_FULL_PRECISION won't be necessary once this is fixed (and torch2 wheels are released for ROCm 6.2): https://github.com/pytorch/pytorch/issues/132570#issuecomment-2313071756
    elif any(device_name.startswith("Vega 2") for device_name in device_names):  # Radeon VII etc
        print("[INFO] Applying Vega 20 settings")
        env["HSA_OVERRIDE_GFX_VERSION"] = "9.0.6"
    elif any(device_name.startswith("Vega 1") for device_name in device_names):  # Radeon RX Vega 56 etc
        print("[INFO] Applying Vega 10 settings")
        env["HSA_OVERRIDE_GFX_VERSION"] = "9.0.0"
        env["FORCE_FULL_PRECISION"] = "yes"
    elif any(device_name.startswith("Ellesmere") for device_name in device_names):  # RX 570, 580, 590, Polaris etc
        print("[INFO] Applying Ellesmere settings")
        env["HSA_OVERRIDE_GFX_VERSION"] = "8.0.3"  # https://github.com/ROCm/ROCm/issues/1659
        env["ROC_ENABLE_PRE_VEGA"] = "1"
    else:
        # Bug fix: the original also assigned env["ROC_ENABLE_PRE_VEGA"] = "1" here,
        # but then returned before _set_env_vars(env) was called, so the assignment
        # was dead code (the tests assert no env vars are set in this case).
        print(f"[WARNING] Unrecognized AMD graphics card: {device_names}")
        return

    # Expose every detected discrete GPU to ROCm/HIP (0-based device indices)
    num_devices = len(device_names)
    env["HIP_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(num_devices))

    _set_env_vars(env)


def _set_env_vars(env):
    """Apply the given mapping to os.environ, logging each assignment."""
    for k, v in env.items():
        print(f"[INFO] Setting env variable {k}={v}")
        os.environ[k] = v
torchruntime/consts.py ADDED
@@ -0,0 +1,5 @@
1
# Where users should report issues / contact the project.
CONTACT_LINK = "https://github.com/easydiffusion/torchruntime"

# PCI vendor IDs (lowercase hex strings, as parsed from the OS device listings)
# of the GPU vendors recognized by torchruntime.
AMD = "1002"
NVIDIA = "10de"
INTEL = "8086"
@@ -0,0 +1,138 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import sqlite3
5
+ import platform
6
+ import subprocess
7
+
8
+ DEVICE_DB_FILE = "gpu_pci_ids.db" # this file will only include AMD, NVIDIA and Discrete Intel GPUs
9
+
10
+ os_name = platform.system()
11
+
12
+
13
def get_windows_output():
    """Return the PNP device IDs of all video controllers (Windows), or "" on failure."""
    cmd = [
        "powershell",
        "-Command",
        "Get-WmiObject Win32_VideoController | ForEach-Object { $_.PNPDeviceID }",
    ]
    try:
        return subprocess.check_output(cmd, text=True, stderr=subprocess.DEVNULL)
    except subprocess.CalledProcessError:
        # powershell ran but failed; treat as "no GPUs detected"
        return ""
19
+
20
+
21
def get_linux_output():
    """Return the output of `lspci -nn` (Linux), or "" when lspci is unavailable."""
    cmd = ["lspci", "-nn"]
    try:
        return subprocess.check_output(cmd, text=True)
    except FileNotFoundError:
        # lspci isn't installed (e.g. minimal containers)
        return ""
26
+
27
+
28
def get_macos_output():
    """Return `system_profiler -json SPDisplaysDataType` output (macOS), or "" on failure."""
    cmd = ["system_profiler", "-json", "SPDisplaysDataType"]
    try:
        return subprocess.check_output(cmd, text=True)
    except subprocess.CalledProcessError:
        return ""
33
+
34
+
35
def parse_windows_output(output):
    """Extract unique (vendor_id, device_id) pairs from PNP device-ID lines."""
    pattern = re.compile(r"VEN_(\w+)&DEV_(\w+)", re.IGNORECASE)
    found = set()
    for line in output.splitlines():
        m = pattern.search(line)
        if m:
            # normalize hex ids to lowercase for stable comparisons
            found.add((m.group(1).lower(), m.group(2).lower()))
    return list(found)
44
+
45
+
46
def parse_linux_output(output):
    """Extract unique (vendor_id, device_id) pairs from `lspci -nn` output."""
    found = set()
    for line in output.splitlines():
        m = re.search(r"\[(\w+):(\w+)\]", line)
        if m:
            vid, did = m.group(1).lower(), m.group(2).lower()
            found.add((vid, did))
    return list(found)
55
+
56
+
57
def parse_macos_output(output):
    """Extract unique (vendor_id, device_id) pairs from system_profiler JSON output."""
    try:
        displays = json.loads(output).get("SPDisplaysDataType", [])
    except json.JSONDecodeError:
        return []

    found = set()
    for display in displays:
        vendor_raw = display.get("spdisplays_vendor", "")
        device_raw = display.get("spdisplays_device-id", "")
        if not (vendor_raw and device_raw):
            continue

        device_id = device_raw.replace("0x", "").lower()
        if "Intel" in vendor_raw:
            # Intel entries don't embed a PCI id in the vendor string
            vendor_id = "8086"
        else:
            m = re.search(r"\((0x\w+)\)", vendor_raw)
            if not m:
                continue  # vendor string has no recognizable PCI id
            vendor_id = m.group(1).replace("0x", "").lower()

        found.add((vendor_id, device_id))
    return list(found)
79
+
80
+
81
def get_pci_ids():
    """Detect the (vendor_id, device_id) pairs of this machine's GPUs, per-OS."""
    if os_name == "Windows":
        return parse_windows_output(get_windows_output())
    if os_name == "Linux":
        return parse_linux_output(get_linux_output())
    if os_name == "Darwin":  # macOS
        return parse_macos_output(get_macos_output())
    # Unknown OS: nothing we can detect
    return []
93
+
94
+
95
def get_device_infos(pci_ids):
    """
    Look up the given PCI ids in the bundled SQLite database (gpu_pci_ids.db),
    which only includes AMD, NVIDIA and discrete Intel GPUs.

    Fix: the previous docstring documented a `db_file_name` parameter that does
    not exist — the DB path is derived from this module's location.

    Args:
        pci_ids (list of tuples): List of (vendor_id, device_id) pairs to match.

    Returns:
        list of tuples: (vendor_id, vendor_name, device_id, device_name) rows for
        every pair found in the database (pairs with no match are skipped).
    """
    result = []

    # The database ships inside the package, next to this module
    db_path = os.path.join(os.path.dirname(__file__), DEVICE_DB_FILE)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    query = """
        SELECT vendor_id, vendor_name, device_id, device_name
        FROM pci_ids
        WHERE vendor_id = ? AND device_id = ?
    """

    try:
        # One parameterized query per (vendor_id, device_id) pair
        for vendor_id, device_id in pci_ids:
            cursor.execute(query, (vendor_id, device_id))
            result.extend(cursor.fetchall())
    finally:
        # Always release the connection, even if a query fails
        conn.close()

    return result
134
+
135
+
136
def get_discrete_gpus():
    """Return (vendor_id, vendor_name, device_id, device_name) for each discrete GPU found."""
    return get_device_infos(get_pci_ids())
Binary file
@@ -0,0 +1,98 @@
1
+ import re
2
+ import sys
3
+ import platform
4
+ import subprocess
5
+
6
+ from .consts import CONTACT_LINK
7
+ from .device_db import get_discrete_gpus
8
+ from .platform_detection import get_torch_platform
9
+
10
os_name = platform.system()

# Prefix for every pip invocation (uses the current interpreter's pip)
PIP_PREFIX = [sys.executable, "-m", "pip", "install"]

# Matches CUDA platform strings like "cu112" or "cu124".
# Bug fix: the original pattern was r"^cu[\d+]" — a character class matching a
# single digit *or a literal '+'*, so malformed strings like "cu+" were accepted.
CUDA_REGEX = re.compile(r"^cu\d+")


def get_install_commands(torch_platform, packages):
    """
    Generates pip installation commands for PyTorch and related packages based on the specified platform.

    Args:
        torch_platform (str): Target platform for PyTorch. Must be one of:
            - "cpu"
            - "cuXXX" (e.g., "cu112", "cu126")
            - "rocmXXX" (e.g., "rocm4.2", "rocm6.2")
            - "xpu"
            - "directml"
            - "ipex"
        packages (list of str): List of package names (and optionally versions in pip format). Examples:
            - ["torch", "torchvision"]
            - ["torch>=2.0", "torchaudio==0.16.0"]
            Falsy values default to ["torch", "torchaudio", "torchvision"].

    Returns:
        list of list of str: Each sublist contains a pip install command (excluding the `pip install` prefix).
        Examples:
            - [["torch", "--index-url", "https://foo.com/whl"]]
            - [["torch-directml"], ["torch", "torchvision"]]

    Raises:
        ValueError: If an unsupported platform is provided.

    Notes:
        - For "xpu" on Windows, if torchvision or torchaudio are included, the function switches to nightly builds.
        - For "directml", the "torch-directml" package is returned as part of the installation commands.
        - For "ipex", the "intel-extension-for-pytorch" package is returned as part of the installation commands.
    """
    packages = packages or ["torch", "torchaudio", "torchvision"]

    if torch_platform == "cpu":
        # Default PyPI wheels are CPU builds; no extra index needed
        return [packages]

    if CUDA_REGEX.match(torch_platform) or torch_platform.startswith("rocm"):
        index_url = f"https://download.pytorch.org/whl/{torch_platform}"
        return [packages + ["--index-url", index_url]]

    if torch_platform == "xpu":
        if os_name == "Windows" and ("torchvision" in packages or "torchaudio" in packages):
            print(
                f"[WARNING] The preview build of 'xpu' on Windows currently only supports torch, not torchvision/torchaudio. "
                f"torch-installer will instead use the nightly build, to get the 'xpu' version of torchaudio and torchvision as well. "
                f"Please contact torch-installer if this is no longer accurate: {CONTACT_LINK}"
            )
            index_url = f"https://download.pytorch.org/whl/nightly/{torch_platform}"
        else:
            index_url = f"https://download.pytorch.org/whl/test/{torch_platform}"

        return [packages + ["--index-url", index_url]]

    if torch_platform == "directml":
        return [["torch-directml"], packages]

    if torch_platform == "ipex":
        return [packages, ["intel-extension-for-pytorch"]]

    raise ValueError(f"Unsupported platform: {torch_platform}")
76
+
77
+
78
def get_pip_commands(cmds):
    """Prefix each command with `<python> -m pip install`; rejects None entries."""
    assert all(cmd is not None for cmd in cmds)
    return [[*PIP_PREFIX, *cmd] for cmd in cmds]
81
+
82
+
83
def run_commands(cmds):
    """Execute each command (a list of argv tokens), echoing it first.

    Note: return codes are not checked — a failed pip install does not raise.
    """
    for cmd in cmds:
        print("> ", cmd)
        subprocess.run(cmd)
87
+
88
+
89
def install(packages=None):
    """
    Install the torch packages appropriate for this machine's GPUs.

    packages: a list of strings with package names (and optionally their versions
        in pip-format). e.g. ["torch", "torchvision"] or
        ["torch>=2.0", "torchaudio==0.16.0"]. Defaults to
        ["torch", "torchvision", "torchaudio"].

    Fix: the default used to be the mutable literal `[]`; use the None sentinel
    instead to avoid the shared-mutable-default pitfall. Behavior is unchanged:
    get_install_commands treats None and [] identically (both falsy).
    """
    discrete_gpu_infos = get_discrete_gpus()
    torch_platform = get_torch_platform(discrete_gpu_infos)
    cmds = get_install_commands(torch_platform, packages)
    cmds = get_pip_commands(cmds)
    run_commands(cmds)
@@ -0,0 +1,115 @@
1
+ import sys
2
+ import platform
3
+
4
+ from .consts import AMD, INTEL, NVIDIA, CONTACT_LINK
5
+
6
# Cached once at import time; the test suite monkeypatches these module
# attributes directly instead of patching platform.system()/machine().
os_name = platform.system()
arch = platform.machine().lower()  # e.g. "amd64", "x86_64", "arm64", "aarch64"
py_version = sys.version_info
9
+
10
+
11
+ def get_torch_platform(discrete_gpu_infos):
12
+ """
13
+ Determine the appropriate PyTorch platform to use based on the system architecture, OS, and GPU information.
14
+
15
+ Args:
16
+ discrete_gpu_infos (list of tuples): A list where each tuple represents a GPU. Each tuple contains:
17
+ - vendor_id (int): The vendor ID of the GPU (e.g., NVIDIA, AMD, INTEL constants).
18
+ - other details (e.g., model, memory, etc., not used directly in this function).
19
+
20
+ Returns:
21
+ str: A string representing the platform to use. Possible values:
22
+ - "cpu": No discrete GPUs or unsupported configuration.
23
+ - "cuXXX": NVIDIA CUDA version (e.g., "cu124").
24
+ - "rocmXXX": AMD ROCm version (e.g., "rocm6.2").
25
+ - "directml": DirectML for AMD or Intel GPUs on Windows.
26
+ - "ipex": Intel Extension for PyTorch for Linux.
27
+ - "xpu": Intel's backend for PyTorch.
28
+
29
+ Raises:
30
+ NotImplementedError: For unsupported architectures, OS-GPU combinations, or multiple GPU vendors.
31
+ Warning: Outputs warnings for deprecated Python versions or fallback configurations.
32
+ """
33
+
34
+ VALID_ARCHS = {
35
+ "Windows": {"amd64"},
36
+ "Linux": {"x86_64", "aarch64"},
37
+ "Darwin": {"x86_64", "arm64"},
38
+ }
39
+
40
+ if arch not in VALID_ARCHS[os_name]:
41
+ raise NotImplementedError(
42
+ f"torch is not currently available for {os_name} on {arch} architecture! If this is no longer true, please contact torch-installer at {CONTACT_LINK}"
43
+ )
44
+
45
+ if len(discrete_gpu_infos) == 0:
46
+ return "cpu"
47
+
48
+ vendor_ids = set(vendor_id for vendor_id, *_ in discrete_gpu_infos)
49
+
50
+ if len(vendor_ids) > 1:
51
+ device_names = list(vendor_name + " " + device_name for _, vendor_name, _, device_name in discrete_gpu_infos)
52
+ raise NotImplementedError(
53
+ f"torch-installer does not currently support multiple graphics card manufacturers on the same computer: {device_names}! Please contact torch-installer at {CONTACT_LINK} with details about your hardware."
54
+ )
55
+
56
+ vendor_id = vendor_ids.pop()
57
+ if vendor_id == AMD:
58
+ if os_name == "Windows":
59
+ return "directml"
60
+ elif os_name == "Linux":
61
+ device_names = set(device_name for *_, device_name in discrete_gpu_infos)
62
+ if any(
63
+ device_name.startswith("Navi 3") or device_name.startswith("Navi 2") for device_name in device_names
64
+ ):
65
+ if py_version < (3, 9):
66
+ print(
67
+ "[WARNING] Support for Python 3.8 was dropped in ROCm 6.2. torch-installer will default to using ROCm 6.1 instead, but consider switching to a newer Python version to use the latest ROCm!"
68
+ )
69
+ return "rocm6.1"
70
+ return "rocm6.2"
71
+ elif any(device_name.startswith("Navi 1") for device_name in device_names):
72
+ return "rocm5.2"
73
+ elif any(
74
+ device_name.startswith("Vega 1") or device_name.startswith("Vega 2") for device_name in device_names
75
+ ):
76
+ return "rocm5.7"
77
+ elif any(device_name.startswith("Ellesmere") for device_name in device_names):
78
+ return "rocm4.2"
79
+
80
+ print(
81
+ f"[WARNING] Unsupported AMD graphics card: {device_names}. If this is a recent graphics card (less than 8 years old), please contact torch-installer at {CONTACT_LINK} with details about your hardware."
82
+ )
83
+ return "cpu"
84
+ elif os_name == "Darwin":
85
+ return "rocm4.2"
86
+ elif vendor_id == NVIDIA:
87
+ if os_name in ("Windows", "Linux"):
88
+ return "cu124"
89
+ elif os_name == "Darwin":
90
+ raise NotImplementedError(
91
+ f"torch-installer does not currently support NVIDIA graphics cards on Macs! Please contact torch-installer at {CONTACT_LINK}"
92
+ )
93
+ elif vendor_id == INTEL:
94
+ if os_name == "Windows":
95
+ if py_version < (3, 9):
96
+ print(
97
+ "[WARNING] Support for Python 3.8 was dropped in torch 2.5, which supports a higher-performance 'xpu' backend for Intel. torch-installer will default to using 'directml' instead, but consider switching to a newer Python version to use the latest 'xpu' backend for Intel!"
98
+ )
99
+ return "directml"
100
+ return "xpu"
101
+ elif os_name == "Linux":
102
+ if py_version < (3, 9):
103
+ print(
104
+ "[WARNING] Support for Python 3.8 was dropped in torch 2.5, which supports a higher-performance 'xpu' backend for Intel. torch-installer will default to using 'intel-extension-for-pytorch' instead, but consider switching to a newer Python version to use the latest 'xpu' backend for Intel!"
105
+ )
106
+ return "ipex"
107
+ return "xpu"
108
+ else:
109
+ raise NotImplementedError(
110
+ f"torch-installer does not currently support Intel graphics cards on Macs! Please contact torch-installer at {CONTACT_LINK}"
111
+ )
112
+
113
+ print(f"Unrecognized vendor: {discrete_gpu_infos}")
114
+
115
+ return "cpu"
@@ -0,0 +1,137 @@
1
+ import os
2
+ import pytest
3
+ from torchruntime.configure import set_rocm_env_vars
4
+ from torchruntime.consts import AMD
5
+
6
+
7
def create_gpu_info(device_id, device_name):
    # Build a GPU row in the same shape as torchruntime.device_db results:
    # (vendor_id, vendor_name, device_id, device_name)
    return (AMD, "Advanced Micro Devices, Inc. [AMD/ATI]", device_id, device_name)


@pytest.fixture(autouse=True)
def clean_env():
    # Remove relevant environment variables before each test
    env_vars = [
        "HSA_OVERRIDE_GFX_VERSION",
        "HIP_VISIBLE_DEVICES",
        "ROC_ENABLE_PRE_VEGA",
        "HSA_ENABLE_SDMA",
        "FORCE_FULL_PRECISION",
    ]

    # Store original values
    original_values = {}
    for var in env_vars:
        if var in os.environ:
            original_values[var] = os.environ[var]
            del os.environ[var]

    yield

    # Restore original values and remove any new ones
    for var in env_vars:
        if var in os.environ and var not in original_values:
            del os.environ[var]
        elif var in original_values:
            os.environ[var] = original_values[var]


def test_navi_3_settings():
    gpus = [create_gpu_info("123", "Navi 31 XTX")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "11.0.0"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ
    assert "FORCE_FULL_PRECISION" not in os.environ


def test_navi_2_settings():
    gpus = [create_gpu_info("123", "Navi 21 XTX")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "10.3.0"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ
    assert "FORCE_FULL_PRECISION" not in os.environ


def test_navi_1_settings():
    # Navi 1x additionally forces full precision (see ROCm issue 2527)
    gpus = [create_gpu_info("123", "Navi 14")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "10.3.0"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
    assert os.environ.get("FORCE_FULL_PRECISION") == "yes"
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ


def test_vega_2_settings():
    gpus = [create_gpu_info("123", "Vega 20 Radeon VII")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "9.0.6"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
    assert "FORCE_FULL_PRECISION" not in os.environ
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ


def test_vega_1_settings():
    gpus = [create_gpu_info("123", "Vega 10")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "9.0.0"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
    assert os.environ.get("FORCE_FULL_PRECISION") == "yes"
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ


def test_ellesmere_settings():
    gpus = [create_gpu_info("123", "Ellesmere RX 580")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "8.0.3"
    assert os.environ.get("ROC_ENABLE_PRE_VEGA") == "1"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
    assert "FORCE_FULL_PRECISION" not in os.environ


def test_unknown_gpu_settings():
    # Unrecognized families must leave the environment untouched
    gpus = [create_gpu_info("123", "Unknown GPU")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert "ROC_ENABLE_PRE_VEGA" not in os.environ
    assert "HIP_VISIBLE_DEVICES" not in os.environ
    assert "HSA_OVERRIDE_GFX_VERSION" not in os.environ
    assert "FORCE_FULL_PRECISION" not in os.environ


def test_multiple_gpus_same_type():
    gpus = [create_gpu_info("123", "Navi 31 XTX"), create_gpu_info("124", "Navi 31 XT")]
    set_rocm_env_vars(gpus, "rocm6.2")

    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "11.0.0"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0,1"
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ
    assert "FORCE_FULL_PRECISION" not in os.environ


def test_multiple_gpus_mixed_types():
    gpus = [create_gpu_info("123", "Navi 31 XTX"), create_gpu_info("124", "Navi 21 XT")]
    set_rocm_env_vars(gpus, "rocm6.2")

    # Should use Navi 3 settings since at least one GPU is Navi 3
    assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "11.0.0"
    assert os.environ.get("HIP_VISIBLE_DEVICES") == "0,1"
    assert "ROC_ENABLE_PRE_VEGA" not in os.environ
    assert "FORCE_FULL_PRECISION" not in os.environ


def test_empty_gpu_list():
    gpus = []
    set_rocm_env_vars(gpus, "rocm6.2")

    assert "ROC_ENABLE_PRE_VEGA" not in os.environ
    assert "HIP_VISIBLE_DEVICES" not in os.environ
    assert "HSA_OVERRIDE_GFX_VERSION" not in os.environ
    assert "FORCE_FULL_PRECISION" not in os.environ
@@ -0,0 +1,50 @@
1
+ from torchruntime.device_db import parse_windows_output, parse_linux_output, parse_macos_output
2
+
3
+
4
def test_parse_windows_output():
    output = """
    PCI\\VEN_8086&DEV_591B&SUBSYS_2212103C&REV_04
    PCI\\VEN_10DE&DEV_1C82&SUBSYS_37131462&REV_A1
    PCI\\VEN_10DE&DEV_2504&SUBSYS_881D1043&REV_A1\\4&22AF55FA&0&0008
    """
    # Fix: the last pair was written as ("10de", ("2504")) — the inner parens were
    # redundant (a parenthesized string, not a tuple) and misleading to readers.
    expected = [("8086", "591b"), ("10de", "1c82"), ("10de", "2504")]
    assert sorted(parse_windows_output(output)) == sorted(expected)


def test_parse_linux_output():
    output = """
    00:02.0 VGA compatible controller: Intel Corporation HD Graphics 620 (rev 02) [8086:5916]
    01:00.0 3D controller: NVIDIA Corporation GP108M [GeForce MX150] (rev a1) [10de:1d10]
    """
    expected = [("8086", "5916"), ("10de", "1d10")]
    assert sorted(parse_linux_output(output)) == sorted(expected)


def test_parse_macos_output():
    output = """
    {
        "SPDisplaysDataType": [
            {
                "spdisplays_vendor": "Intel",
                "spdisplays_device-id": "0x5916"
            },
            {
                "spdisplays_vendor": "NVIDIA (0x10de)",
                "spdisplays_device-id": "0x1c82"
            }
        ]
    }
    """
    expected = [("8086", "5916"), ("10de", "1c82")]
    assert sorted(parse_macos_output(output)) == sorted(expected)


def test_parse_macos_output_invalid_json():
    output = "{ invalid_json: true }"
    assert parse_macos_output(output) == []


def test_empty_outputs():
    assert parse_windows_output("") == []
    assert parse_linux_output("") == []
    assert parse_macos_output("") == []
@@ -0,0 +1,50 @@
1
+ # Integration test, which connects to the database and checks for some common devices
2
+
3
+ from torchruntime.device_db import get_device_infos, DEVICE_DB_FILE
4
+
5
+
6
def test_db_file_exists():
    import os
    from torchruntime import device_db

    # The DB ships inside the installed package, next to device_db.py
    db_path = os.path.join(os.path.dirname(device_db.__file__), DEVICE_DB_FILE)
    assert os.path.exists(db_path)


def test_get_single_device():
    """Test retrieving a single device."""
    result = get_device_infos([("8086", "56a7")])
    assert len(result) == 1
    assert result[0] == ("8086", "Intel Corporation", "56a7", "DG2 [Arc Xe Graphics]")


def test_get_multiple_devices():
    """Test retrieving multiple devices."""
    input_ids = [("10de", "2786"), ("10de", "2504")]
    result = get_device_infos(input_ids)
    assert len(result) == 2
    assert ("10de", "NVIDIA Corporation", "2786", "AD104 [GeForce RTX 4070]") in result
    assert ("10de", "NVIDIA Corporation", "2504", "GA106 [GeForce RTX 3060 Lite Hash Rate]") in result


def test_get_amd_devices():
    """Test retrieving AMD devices."""
    input_ids = [("1002", "9495"), ("1002", "747e")]
    result = get_device_infos(input_ids)
    assert len(result) == 2
    assert ("1002", "Advanced Micro Devices, Inc. [AMD/ATI]", "9495", "RV730 [Radeon HD 4600 AGP Series]") in result
    assert ("1002", "Advanced Micro Devices, Inc. [AMD/ATI]", "747e", "Navi 32 [Radeon RX 7700 XT / 7800 XT]") in result


def test_get_nonexistent_device():
    """Test retrieving a device that doesn't exist in the database."""
    result = get_device_infos([("ffff", "ffff")])
    assert len(result) == 0


def test_get_mixed_existing_and_nonexistent():
    """Test retrieving a mix of existing and non-existing devices."""
    input_ids = [("8086", "56a7"), ("ffff", "ffff")]  # first pair exists, second doesn't
    result = get_device_infos(input_ids)
    assert len(result) == 1
    assert result[0] == ("8086", "Intel Corporation", "56a7", "DG2 [Arc Xe Graphics]")
@@ -0,0 +1,109 @@
1
+ import sys
2
+ import pytest
3
+ from unittest.mock import patch
4
+ from torchruntime.installer import get_install_commands, get_pip_commands, run_commands
5
+
6
+
7
def test_empty_args():
    # No packages requested -> the default torch/torchaudio/torchvision trio
    packages = []
    result = get_install_commands("cpu", packages)
    assert result == [["torch", "torchaudio", "torchvision"]]


def test_cpu_platform():
    packages = ["torch", "torchvision"]
    result = get_install_commands("cpu", packages)
    assert result == [packages]


def test_cuda_platform():
    packages = ["torch", "torchvision"]
    result = get_install_commands("cu112", packages)
    expected_url = "https://download.pytorch.org/whl/cu112"
    assert result == [packages + ["--index-url", expected_url]]


def test_rocm_platform():
    packages = ["torch", "torchvision"]
    result = get_install_commands("rocm4.2", packages)
    expected_url = "https://download.pytorch.org/whl/rocm4.2"
    assert result == [packages + ["--index-url", expected_url]]


def test_xpu_platform_windows_with_torch_only(monkeypatch):
    monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
    packages = ["torch"]
    result = get_install_commands("xpu", packages)
    expected_url = "https://download.pytorch.org/whl/test/xpu"
    assert result == [packages + ["--index-url", expected_url]]


def test_xpu_platform_windows_with_torchvision(monkeypatch, capsys):
    # torchvision on Windows xpu forces the nightly channel and prints a warning
    monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
    packages = ["torch", "torchvision"]
    result = get_install_commands("xpu", packages)
    expected_url = "https://download.pytorch.org/whl/nightly/xpu"
    assert result == [packages + ["--index-url", expected_url]]
    captured = capsys.readouterr()
    assert "[WARNING]" in captured.out


def test_xpu_platform_linux(monkeypatch):
    monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
    packages = ["torch", "torchvision"]
    result = get_install_commands("xpu", packages)
    expected_url = "https://download.pytorch.org/whl/test/xpu"
    assert result == [packages + ["--index-url", expected_url]]


def test_directml_platform():
    packages = ["torch", "torchvision"]
    result = get_install_commands("directml", packages)
    assert result == [["torch-directml"], packages]


def test_ipex_platform():
    packages = ["torch", "torchvision"]
    result = get_install_commands("ipex", packages)
    assert result == [packages, ["intel-extension-for-pytorch"]]


def test_unsupported_platform():
    packages = ["torch", "torchvision"]
    with pytest.raises(ValueError, match="Unsupported platform: unknown"):
        get_install_commands("unknown", packages)


def test_get_pip_commands_valid():
    cmds = [["package1"], ["package2", "--upgrade"]]
    expected = [
        [sys.executable, "-m", "pip", "install", "package1"],
        [sys.executable, "-m", "pip", "install", "package2", "--upgrade"],
    ]

    result = get_pip_commands(cmds)
    assert result == expected


def test_get_pip_commands_none_input():
    # A None entry must trip the assertion in get_pip_commands
    cmds = [["package1"], None]
    with pytest.raises(AssertionError):
        get_pip_commands(cmds)


# Test suite for run_commands
def test_run_commands():
    cmds = [
        [sys.executable, "-m", "pip", "install", "package1"],
        [sys.executable, "-m", "pip", "install", "package2", "--upgrade"],
    ]

    with patch("subprocess.run") as mock_run:
        run_commands(cmds)

        # Ensure subprocess.run was called for each command
        assert mock_run.call_count == len(cmds)

        # Check that subprocess.run was called with the correct arguments
        mock_run.assert_any_call(cmds[0])
        mock_run.assert_any_call(cmds[1])
@@ -0,0 +1,164 @@
1
+ import pytest
2
+ from torchruntime.platform_detection import get_torch_platform, AMD, NVIDIA, INTEL, os_name, arch, py_version
3
+
4
+
5
def test_no_discrete_gpus_windows(monkeypatch):
    """No discrete GPU on Windows x86_64 falls back to the CPU build."""
    for attr, value in (("os_name", "Windows"), ("arch", "amd64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    assert get_torch_platform([]) == "cpu"
9
+
10
+
11
def test_no_discrete_gpus_linux(monkeypatch):
    """No discrete GPU on Linux x86_64 falls back to the CPU build."""
    for attr, value in (("os_name", "Linux"), ("arch", "x86_64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    assert get_torch_platform([]) == "cpu"
15
+
16
+
17
def test_no_discrete_gpus_mac(monkeypatch):
    """No discrete GPU on Apple Silicon falls back to the CPU build (mps at runtime)."""
    for attr, value in (("os_name", "Darwin"), ("arch", "arm64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    assert get_torch_platform([]) == "cpu"
21
+
22
+
23
def test_amd_gpu_windows(monkeypatch):
    """AMD on Windows is served via DirectML (no ROCm there)."""
    for attr, value in (("os_name", "Windows"), ("arch", "amd64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(AMD, "AMD", 0x1234, "Radeon")]
    assert get_torch_platform(gpus) == "directml"
28
+
29
+
30
def test_amd_gpu_navi3_linux(monkeypatch, capsys):
    """Navi 3x on Linux gets ROCm 6.2, capped at 6.1 (with a notice) on Python 3.8."""
    for attr, value in (("os_name", "Linux"), ("arch", "x86_64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(AMD, "AMD", 0x1234, "Navi 31")]
    old_python = py_version < (3, 9)
    assert get_torch_platform(gpus) == ("rocm6.1" if old_python else "rocm6.2")
    if old_python:
        out = capsys.readouterr().out
        assert "Support for Python 3.8 was dropped in ROCm 6.2" in out
39
+
40
+
41
def test_amd_gpu_navi2_linux(monkeypatch, capsys):
    """Navi 2x on Linux gets ROCm 6.2, capped at 6.1 (with a notice) on Python 3.8."""
    for attr, value in (("os_name", "Linux"), ("arch", "x86_64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(AMD, "AMD", 0x1234, "Navi 22")]
    old_python = py_version < (3, 9)
    assert get_torch_platform(gpus) == ("rocm6.1" if old_python else "rocm6.2")
    if old_python:
        out = capsys.readouterr().out
        assert "Support for Python 3.8 was dropped in ROCm 6.2" in out
50
+
51
+
52
def test_amd_gpu_navi1_linux(monkeypatch):
    """Navi 1x on Linux is pinned to ROCm 5.2."""
    for attr, value in (("os_name", "Linux"), ("arch", "x86_64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(AMD, "AMD", 0x1234, "Navi 10")]
    assert get_torch_platform(gpus) == "rocm5.2"
57
+
58
+
59
def test_amd_gpu_vega1_linux(monkeypatch):
    """Vega 10 on Linux maps to ROCm 5.7."""
    for attr, value in (("os_name", "Linux"), ("arch", "x86_64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(AMD, "AMD", 0x1234, "Vega 10")]
    assert get_torch_platform(gpus) == "rocm5.7"
64
+
65
+
66
def test_amd_gpu_vega2_linux(monkeypatch):
    """Vega 20 on Linux maps to ROCm 5.7."""
    for attr, value in (("os_name", "Linux"), ("arch", "x86_64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(AMD, "AMD", 0x1234, "Vega 20")]
    assert get_torch_platform(gpus) == "rocm5.7"
71
+
72
+
73
def test_amd_gpu_ellesmere_linux(monkeypatch):
    """Ellesmere (Polaris) on Linux maps to the legacy ROCm 4.2."""
    for attr, value in (("os_name", "Linux"), ("arch", "x86_64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(AMD, "AMD", 0x1234, "Ellesmere")]
    assert get_torch_platform(gpus) == "rocm4.2"
78
+
79
+
80
def test_amd_gpu_unsupported_linux(monkeypatch, capsys):
    """Unknown AMD models fall back to CPU and print a warning."""
    for attr, value in (("os_name", "Linux"), ("arch", "x86_64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(AMD, "AMD", 0x1234, "UnknownModel")]
    assert get_torch_platform(gpus) == "cpu"
    assert "[WARNING] Unsupported AMD graphics card" in capsys.readouterr().out
87
+
88
+
89
def test_amd_gpu_mac(monkeypatch):
    """AMD on Mac maps to the legacy ROCm 4.2 platform id."""
    for attr, value in (("os_name", "Darwin"), ("arch", "arm64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(AMD, "AMD", 0x1234, "Radeon")]
    assert get_torch_platform(gpus) == "rocm4.2"
94
+
95
+
96
def test_nvidia_gpu_windows(monkeypatch):
    """NVIDIA on Windows selects the CUDA 12.4 build."""
    for attr, value in (("os_name", "Windows"), ("arch", "amd64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(NVIDIA, "NVIDIA", 0x1234, "GeForce")]
    assert get_torch_platform(gpus) == "cu124"
101
+
102
+
103
def test_nvidia_gpu_linux(monkeypatch):
    """NVIDIA on Linux selects the CUDA 12.4 build."""
    for attr, value in (("os_name", "Linux"), ("arch", "x86_64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(NVIDIA, "NVIDIA", 0x1234, "GeForce")]
    assert get_torch_platform(gpus) == "cu124"
108
+
109
+
110
def test_nvidia_gpu_mac(monkeypatch):
    """NVIDIA on Mac is not a supported combination."""
    for attr, value in (("os_name", "Darwin"), ("arch", "arm64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(NVIDIA, "NVIDIA", 0x1234, "GeForce")]
    with pytest.raises(NotImplementedError):
        get_torch_platform(gpus)
116
+
117
+
118
def test_intel_gpu_windows(monkeypatch):
    """Intel on Windows gets xpu, except DirectML on Python < 3.9."""
    for attr, value in (("os_name", "Windows"), ("arch", "amd64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(INTEL, "Intel", 0x1234, "Iris")]
    assert get_torch_platform(gpus) == ("directml" if py_version < (3, 9) else "xpu")
124
+
125
+
126
def test_intel_gpu_linux(monkeypatch):
    """Intel on Linux gets xpu, except ipex on Python < 3.9."""
    for attr, value in (("os_name", "Linux"), ("arch", "x86_64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(INTEL, "Intel", 0x1234, "Iris")]
    assert get_torch_platform(gpus) == ("ipex" if py_version < (3, 9) else "xpu")
132
+
133
+
134
def test_intel_gpu_mac(monkeypatch):
    """Intel discrete GPU on Mac is not a supported combination."""
    for attr, value in (("os_name", "Darwin"), ("arch", "arm64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [(INTEL, "Intel", 0x1234, "Iris")]
    with pytest.raises(NotImplementedError):
        get_torch_platform(gpus)
140
+
141
+
142
def test_multiple_gpu_vendors(monkeypatch):
    """Mixed-vendor GPU setups are not handled and must raise."""
    for attr, value in (("os_name", "Windows"), ("arch", "amd64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [
        (AMD, "AMD", 0x1234, "Radeon"),
        (NVIDIA, "NVIDIA", 0x5678, "GeForce"),
    ]
    with pytest.raises(NotImplementedError):
        get_torch_platform(gpus)
151
+
152
+
153
def test_unsupported_architecture(monkeypatch):
    """An unknown CPU architecture (e.g. sparc) must raise."""
    for attr, value in (("os_name", "Linux"), ("arch", "sparc")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    with pytest.raises(NotImplementedError):
        get_torch_platform([])
158
+
159
+
160
def test_unrecognized_gpu_vendor(monkeypatch):
    """An unknown GPU vendor id falls back to the CPU build."""
    for attr, value in (("os_name", "Windows"), ("arch", "amd64")):
        monkeypatch.setattr(f"torchruntime.platform_detection.{attr}", value)
    gpus = [("9999", "UnknownVendor", 0x1234, "Unknown")]
    assert get_torch_platform(gpus) == "cpu"
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 easydiffusion
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,145 @@
1
+ Metadata-Version: 2.2
2
+ Name: torchruntime
3
+ Version: 1.0.2
4
+ Summary: Meant for app developers. A convenient way to install and configure the appropriate version of PyTorch on the user's computer, based on the OS and GPU manufacturer and model number.
5
+ Author-email: cmdr2 <secondary.cmdr2@gmail.com>
6
+ Project-URL: Homepage, https://github.com/easydiffusion/torchruntime
7
+ Project-URL: Bug Tracker, https://github.com/easydiffusion/torchruntime/issues
8
+ Keywords: torch,ai,ml,llm,installer,runtime
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Operating System :: Microsoft :: Windows :: Windows 10
12
+ Classifier: Operating System :: Microsoft :: Windows :: Windows 11
13
+ Classifier: Operating System :: POSIX :: Linux
14
+ Classifier: Operating System :: MacOS
15
+ Requires-Python: >=3.0
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+
19
+ # torchruntime
20
+ [![Discord Server](https://img.shields.io/discord/1014774730907209781?label=Discord)](https://discord.com/invite/u9yhsFmEkB)
21
+
22
+ **torchruntime** is a lightweight package for automatically installing the appropriate variant of PyTorch on a user's computer, based on their OS, and GPU manufacturer and GPU model.
23
+
24
+ This package is used by [Easy Diffusion](https://github.com/easydiffusion/easydiffusion), but you're welcome to use it as well. It's useful for developers who make PyTorch-based apps that target users with NVIDIA, AMD and Intel graphics cards (as well as CPU-only usage), on Windows, Mac and Linux.
25
+
26
+ ### Why?
27
+ It lets you treat PyTorch as a single dependency (like it should be), and lets you assume that each user will get the most-performant variant of PyTorch suitable for their computer's OS and hardware.
28
+
29
+ It deals with the complexity of the variety of torch builds and configurations required for CUDA, AMD (ROCm, DirectML), Intel (xpu/DirectML/ipex), and CPU-only.
30
+
31
+ **Compatibility table**: [Click here](#compatibility-table) to see the supported graphics cards and operating systems.
32
+
33
+ # Installation
34
+ Supports Windows, Linux, and Mac.
35
+
36
+ `pip install torchruntime`
37
+
38
+ ## Usage
39
+ ### Step 1. Install the appropriate variant of PyTorch
40
+ *This command should be run on the user's computer, or while creating platform-specific builds:*
41
+
42
+ `python -m torchruntime install`
43
+
44
+ This will install `torch`, `torchvision`, and `torchaudio`, and will decide the variant based on the user's OS, GPU manufacturer and GPU model number. See [customizing packages](#customizing-packages) for more options.
45
+
46
+ ### Step 2. Initialize torch
47
+ This should be run inside your program, to initialize the required environment variables (if any) for the variant of torch being used.
48
+
49
+ ```py
50
+ import torchruntime
51
+
52
+ torchruntime.init_torch()
53
+ ```
54
+
55
+ ## Customizing packages
56
+ By default, `python -m torchruntime install` will install the latest available `torch`, `torchvision` and `torchaudio` suitable for the user's platform.
57
+
58
+ You can customize the packages to install by including their names:
59
+ * For e.g. to install only `torch` and `torchvision`, you can run `python -m torchruntime install torch torchvision`
60
+ * To install specific versions (in pip format), you can run `python -m torchruntime install "torch>2.0" "torchvision==0.20"`
61
+
62
+ **Note:** If you specify package versions, please keep in mind that the version may not be available to *all* the users on *all* the torch platforms. For e.g. a user with Python 3.8 would not be able to install torch 2.5 (or higher), because torch 2.5 dropped support for Python 3.8.
63
+
64
+ So in general, it's better to avoid specifying a version unless it really matters to you (or you know what you're doing). Instead, please allow `torchruntime` to pick the latest-possible version for the user.
65
+
66
+ # Compatibility table
67
+ The list of platforms on which `torchruntime` can install a working variant of PyTorch.
68
+
69
+ **Note:** *This list is based on user feedback (since I don't have all the cards). Please let me know if your card is supported (or not) by opening a pull request or issue or messaging on [Discord](https://discord.com/invite/u9yhsFmEkB) (with supporting logs).*
70
+
71
+ **CPU-only:**
72
+
73
+ | OS | Supported?| Notes |
74
+ |---|---|---|
75
+ | Windows | ✅ Yes | x86_64 |
76
+ | Linux | ✅ Yes | x86_64 and aarch64 |
77
+ | Mac (M1/M2/M3/M4) | ✅ Yes | arm64. `mps` backend |
78
+ | Mac (Intel) | ✅ Yes | x86_64. Stopped after `torch 2.2.2` |
79
+
80
+ **NVIDIA:**
81
+
82
+ | Series | Supported? | OS | Notes |
83
+ |---|---|---|---|
84
+ | 40xx | ✅ Yes | Win/Linux | Uses CUDA 124 |
85
+ | 30xx | ✅ Yes | Win/Linux | Uses CUDA 124 |
86
+ | 20xx | ✅ Yes | Win/Linux | Uses CUDA 124 |
87
+ | 10xx/16xx | ✅ Yes | Win/Linux | Uses CUDA 124. Full-precision required on 16xx series |
88
+
89
+ **AMD:**
90
+
91
+ | Series | Supported? | OS | Notes |
92
+ |---|---|---|---|
93
+ | 7xxx | ✅ Yes | Win/Linux | Navi3/RDNA3 (gfx110x). ROCm 6.2 on Linux. DirectML on Windows |
94
+ | 6xxx | ✅ Yes | Win/Linux | Navi2/RDNA2 (gfx103x). ROCm 6.2 on Linux. DirectML on Windows |
95
+ | 5xxx | ✅ Yes | Win/Linux | Navi1/RDNA1 (gfx101x). Full-precision required. DirectML on Windows. Linux only supports up to ROCm 5.2. Waiting for [this](https://github.com/pytorch/pytorch/issues/132570#issuecomment-2313071756) for ROCm 6.2 support. |
96
+ | 5xxx on Intel Mac | ❓ Untested (WIP) | Intel Mac | gfx101x. Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
97
+ | 4xxxG/Radeon VII | ✅ Yes | Win/Linux | Vega 20 gfx906. Need testers for Windows, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
98
+ | 2xxxG/Radeon RX Vega 56 | ❓ Untested (WIP) | N/A | Vega 10 gfx900. Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
99
+ | 5xx/Polaris | ❓ Untested (WIP) | N/A | gfx80x. Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
100
+
101
+ **Apple:**
102
+
103
+ | Series | Supported? |Notes |
104
+ |---|---|---|
105
+ | M1/M2/M3/M4 | ✅ Yes | 'mps' backend |
106
+ | AMD on Intel Mac | ❓ Untested (WIP) | Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
107
+
108
+ **Intel:**
109
+
110
+ | Series | Supported? | OS | Notes |
111
+ |---|---|---|---|
112
+ | Arc | ❓ Untested (WIP) | Win/Linux | Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB). Backends: 'xpu' or DirectML or [ipex](https://github.com/intel/intel-extension-for-pytorch) |
113
+
114
+
115
+ # FAQ
116
+ ## Why can't I just run 'pip install torch'?
117
+ `pip install torch` installs the CPU-only version of torch, so it won't utilize your GPU's capabilities.
118
+
119
+ ## Why can't I just install torch-for-ROCm directly to support AMD?
120
+ Different models of AMD cards require different LLVM targets, and sometimes different ROCm versions. And ROCm currently doesn't work on Windows, so AMD on Windows is best served (currently) with DirectML.
121
+
122
+ And plenty of AMD cards work with ROCm (even when they aren't in the official list of supported cards). Information about these cards (for e.g. the LLVM target to use) is pretty scattered.
123
+
124
+ `torchruntime` deals with this complexity for your convenience.
125
+
126
+ # Contributing
127
+ 📢 I'm looking for contributions in these specific areas:
128
+ - More testing on consumer AMD GPUs.
129
+ - More support for older AMD GPUs. Explore: Compile and host PyTorch wheels and rocm (on GitHub) for older AMD gpus (e.g. 580/590/Polaris) with the required patches.
130
+ - Intel GPUs.
131
+ - Testing on professional AMD GPUs (e.g. the Instinct series).
132
+ - An easy-to-run benchmark script (that people can run to check the level of compatibility on their platform).
133
+
134
+ Please message on the [Discord community](https://discord.com/invite/u9yhsFmEkB) if you have AMD or Intel GPUs, and would like to help with testing or adding support for them! Thanks!
135
+
136
+ # Credits
137
+ * Code contributors on [Easy Diffusion](https://github.com/easydiffusion/easydiffusion).
138
+ * Users on [Easy Diffusion's Discord](https://discord.com/invite/u9yhsFmEkB) who've helped with testing on various GPUs.
139
+ * [PCI Database](https://github.com/pciutils/pciids/) automatically generated from the PCI ID Database at http://pci-ids.ucw.cz
140
+
141
+ # More resources
142
+ * [AMD GPU LLVM Architectures](https://web.archive.org/web/20241228163540/https://llvm.org/docs/AMDGPUUsage.html#processors)
143
+ * [Status of ROCm support for AMD Navi 1](https://github.com/ROCm/ROCm/issues/2527)
144
+ * [Torch support for ROCm 6.2 on AMD Navi 1](https://github.com/pytorch/pytorch/issues/132570#issuecomment-2313071756)
145
+ * [ROCmLibs-for-gfx1103-AMD780M-APU](https://github.com/likelovewant/ROCmLibs-for-gfx1103-AMD780M-APU)
@@ -0,0 +1,18 @@
1
+ torchruntime/__init__.py,sha256=m5-DZXo3-hf9yhDiS8f9TfyQC6ztpE7VHdo1BRQmFE8,65
2
+ torchruntime/__main__.py,sha256=Cf7g3YWVtxqSFSGp6KLeg65JPYOz4QE8atk7gP_SeTM,1449
3
+ torchruntime/configure.py,sha256=MJmxkDBzdyHoGWRmeP9Pgsx4jdySVrTsXk5vnmwLDSA,3983
4
+ torchruntime/consts.py,sha256=-N9FORwsESh8e20xaX3l-Ord05BmNrPFmsdk6iwv0Rg,108
5
+ torchruntime/device_db.py,sha256=VadbEm_kJgNNyhj7Y-LLKdWxZUKFU3-MjDU5xHiBYSo,4154
6
+ torchruntime/gpu_pci_ids.db,sha256=j9LzoFILdTn-149K9MY-SIZk7HUxIXzPd7dl2ybUoZk,221184
7
+ torchruntime/installer.py,sha256=Mfqkv8Eh1XTD5e8ycqtQ3Cjv0Db6NBpBgeZsUciYKJE,3689
8
+ torchruntime/platform_detection.py,sha256=QuBNvS813Dk95KbOiRiaYBcsrkCtI39ftBpIE8n7zsI,5364
9
+ torchruntime/tests/test_configure.py,sha256=emiHac_MnCLq7_V82ActNUgZI37v4ozlENG-jIydGiM,4581
10
+ torchruntime/tests/test_device_db.py,sha256=iAWAbJSJLZS9DCt59SSQVoaKSxoDRnTPrIwCVcKiCfI,1590
11
+ torchruntime/tests/test_device_db_integration.py,sha256=Vwk9z3HG8A8T9j-tWUseWDy_Te9_jjlhco-LOTDI06A,1955
12
+ torchruntime/tests/test_installer.py,sha256=8bwYGmJihdhjfg8YT6QKqjDnO15B2vXAb3DGrrAbkms,3625
13
+ torchruntime/tests/test_platform_detection.py,sha256=-lo7qPuaP7pPKSCxW3a2CkzWE2EmjUESB7jtFsaj8fU,7335
14
+ torchruntime-1.0.2.dist-info/LICENSE,sha256=wrkpT53J7BEDdUpsh9F-0-1j-LNkStyAubQNYet24BA,1070
15
+ torchruntime-1.0.2.dist-info/METADATA,sha256=QS0jG6u_6RW0V48u82y1rEhmMBLO1cuR0zZB92VtSmQ,8533
16
+ torchruntime-1.0.2.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
17
+ torchruntime-1.0.2.dist-info/top_level.txt,sha256=8JeRerOboHkQ_Jw_HiZnctS1d4wGHZa9HZJdd0wWRwQ,13
18
+ torchruntime-1.0.2.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (75.8.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ torchruntime