torchruntime 1.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29) hide show
  1. torchruntime-1.0.2/LICENSE +21 -0
  2. torchruntime-1.0.2/MANIFEST.in +1 -0
  3. torchruntime-1.0.2/PKG-INFO +145 -0
  4. torchruntime-1.0.2/README.md +127 -0
  5. torchruntime-1.0.2/pyproject.toml +33 -0
  6. torchruntime-1.0.2/setup.cfg +4 -0
  7. torchruntime-1.0.2/setup.py +5 -0
  8. torchruntime-1.0.2/tests/test_configure.py +137 -0
  9. torchruntime-1.0.2/tests/test_device_db.py +50 -0
  10. torchruntime-1.0.2/tests/test_device_db_integration.py +50 -0
  11. torchruntime-1.0.2/tests/test_installer.py +109 -0
  12. torchruntime-1.0.2/tests/test_platform_detection.py +164 -0
  13. torchruntime-1.0.2/torchruntime/__init__.py +2 -0
  14. torchruntime-1.0.2/torchruntime/__main__.py +52 -0
  15. torchruntime-1.0.2/torchruntime/configure.py +80 -0
  16. torchruntime-1.0.2/torchruntime/consts.py +5 -0
  17. torchruntime-1.0.2/torchruntime/device_db.py +138 -0
  18. torchruntime-1.0.2/torchruntime/gpu_pci_ids.db +0 -0
  19. torchruntime-1.0.2/torchruntime/installer.py +98 -0
  20. torchruntime-1.0.2/torchruntime/platform_detection.py +115 -0
  21. torchruntime-1.0.2/torchruntime/tests/test_configure.py +137 -0
  22. torchruntime-1.0.2/torchruntime/tests/test_device_db.py +50 -0
  23. torchruntime-1.0.2/torchruntime/tests/test_device_db_integration.py +50 -0
  24. torchruntime-1.0.2/torchruntime/tests/test_installer.py +109 -0
  25. torchruntime-1.0.2/torchruntime/tests/test_platform_detection.py +164 -0
  26. torchruntime-1.0.2/torchruntime.egg-info/PKG-INFO +145 -0
  27. torchruntime-1.0.2/torchruntime.egg-info/SOURCES.txt +27 -0
  28. torchruntime-1.0.2/torchruntime.egg-info/dependency_links.txt +1 -0
  29. torchruntime-1.0.2/torchruntime.egg-info/top_level.txt +1 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 easydiffusion
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1 @@
1
+ include torchruntime/gpu_pci_ids.db
@@ -0,0 +1,145 @@
1
+ Metadata-Version: 2.2
2
+ Name: torchruntime
3
+ Version: 1.0.2
4
+ Summary: Meant for app developers. A convenient way to install and configure the appropriate version of PyTorch on the user's computer, based on the OS and GPU manufacturer and model number.
5
+ Author-email: cmdr2 <secondary.cmdr2@gmail.com>
6
+ Project-URL: Homepage, https://github.com/easydiffusion/torchruntime
7
+ Project-URL: Bug Tracker, https://github.com/easydiffusion/torchruntime/issues
8
+ Keywords: torch,ai,ml,llm,installer,runtime
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Operating System :: Microsoft :: Windows :: Windows 10
12
+ Classifier: Operating System :: Microsoft :: Windows :: Windows 11
13
+ Classifier: Operating System :: POSIX :: Linux
14
+ Classifier: Operating System :: MacOS
15
+ Requires-Python: >=3.0
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+
19
+ # torchruntime
20
+ [![Discord Server](https://img.shields.io/discord/1014774730907209781?label=Discord)](https://discord.com/invite/u9yhsFmEkB)
21
+
22
+ **torchruntime** is a lightweight package for automatically installing the appropriate variant of PyTorch on a user's computer, based on their OS, and GPU manufacturer and GPU model.
23
+
24
+ This package is used by [Easy Diffusion](https://github.com/easydiffusion/easydiffusion), but you're welcome to use it as well. It's useful for developers who make PyTorch-based apps that target users with NVIDIA, AMD and Intel graphics cards (as well as CPU-only usage), on Windows, Mac and Linux.
25
+
26
+ ### Why?
27
+ It lets you treat PyTorch as a single dependency (like it should be), and lets you assume that each user will get the most-performant variant of PyTorch suitable for their computer's OS and hardware.
28
+
29
+ It deals with the complexity of the variety of torch builds and configurations required for CUDA, AMD (ROCm, DirectML), Intel (xpu/DirectML/ipex), and CPU-only.
30
+
31
+ **Compatibility table**: [Click here](#compatibility-table) to see the supported graphics cards and operating systems.
32
+
33
+ # Installation
34
+ Supports Windows, Linux, and Mac.
35
+
36
+ `pip install torchruntime`
37
+
38
+ ## Usage
39
+ ### Step 1. Install the appropriate variant of PyTorch
40
+ *This command should be run on the user's computer, or while creating platform-specific builds:*
41
+
42
+ `python -m torchruntime install`
43
+
44
+ This will install `torch`, `torchvision`, and `torchaudio`, and will decide the variant based on the user's OS, GPU manufacturer and GPU model number. See [customizing packages](#customizing-packages) for more options.
45
+
46
+ ### Step 2. Initialize torch
47
+ This should be run inside your program, to initialize the required environment variables (if any) for the variant of torch being used.
48
+
49
+ ```py
50
+ import torchruntime
51
+
52
+ torchruntime.init_torch()
53
+ ```
54
+
55
+ ## Customizing packages
56
+ By default, `python -m torchruntime install` will install the latest available `torch`, `torchvision` and `torchaudio` suitable on the user's platform.
57
+
58
+ You can customize the packages to install by including their names:
59
+ * For e.g. to install only `torch` and `torchvision`, you can run `python -m torchruntime install torch torchvision`
60
+ * To install specific versions (in pip format), you can run `python -m torchruntime install "torch>2.0" "torchvision==0.20"`
61
+
62
+ **Note:** If you specify package versions, please keep in mind that the version may not be available to *all* the users on *all* the torch platforms. For example, a user with Python 3.8 would not be able to install torch 2.5 (or higher), because torch 2.5 dropped support for Python 3.8.
63
+
64
+ So in general, it's better to avoid specifying a version unless it really matters to you (or you know what you're doing). Instead, please allow `torchruntime` to pick the latest-possible version for the user.
65
+
66
+ # Compatibility table
67
+ The list of platforms on which `torchruntime` can install a working variant of PyTorch.
68
+
69
+ **Note:** *This list is based on user feedback (since I don't have all the cards). Please let me know if your card is supported (or not) by opening a pull request or issue or messaging on [Discord](https://discord.com/invite/u9yhsFmEkB) (with supporting logs).*
70
+
71
+ **CPU-only:**
72
+
73
+ | OS | Supported?| Notes |
74
+ |---|---|---|
75
+ | Windows | ✅ Yes | x86_64 |
76
+ | Linux | ✅ Yes | x86_64 and aarch64 |
77
+ | Mac (M1/M2/M3/M4) | ✅ Yes | arm64. `mps` backend |
78
+ | Mac (Intel) | ✅ Yes | x86_64. Stopped after `torch 2.2.2` |
79
+
80
+ **NVIDIA:**
81
+
82
+ | Series | Supported? | OS | Notes |
83
+ |---|---|---|---|
84
+ | 40xx | ✅ Yes | Win/Linux | Uses CUDA 124 |
85
+ | 30xx | ✅ Yes | Win/Linux | Uses CUDA 124 |
86
+ | 20xx | ✅ Yes | Win/Linux | Uses CUDA 124 |
87
+ | 10xx/16xx | ✅ Yes | Win/Linux | Uses CUDA 124. Full-precision required on 16xx series |
88
+
89
+ **AMD:**
90
+
91
+ | Series | Supported? | OS | Notes |
92
+ |---|---|---|---|
93
+ | 7xxx | ✅ Yes | Win/Linux | Navi3/RDNA3 (gfx110x). ROCm 6.2 on Linux. DirectML on Windows |
94
+ | 6xxx | ✅ Yes | Win/Linux | Navi2/RDNA2 (gfx103x). ROCm 6.2 on Linux. DirectML on Windows |
95
+ | 5xxx | ✅ Yes | Win/Linux | Navi1/RDNA1 (gfx101x). Full-precision required. DirectML on Windows. Linux only supports up to ROCm 5.2. Waiting for [this](https://github.com/pytorch/pytorch/issues/132570#issuecomment-2313071756) for ROCm 6.2 support. |
96
+ | 5xxx on Intel Mac | ❓ Untested (WIP) | Intel Mac | gfx101x. Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
97
+ | 4xxxG/Radeon VII | ✅ Yes | Win/Linux | Vega 20 gfx906. Need testers for Windows, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
98
+ | 2xxxG/Radeon RX Vega 56 | ❓ Untested (WIP) | N/A | Vega 10 gfx900. Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
99
+ | 5xx/Polaris | ❓ Untested (WIP) | N/A | gfx80x. Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
100
+
101
+ **Apple:**
102
+
103
+ | Series | Supported? |Notes |
104
+ |---|---|---|
105
+ | M1/M2/M3/M4 | ✅ Yes | 'mps' backend |
106
+ | AMD on Intel Mac | ❓ Untested (WIP) | Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
107
+
108
+ **Intel:**
109
+
110
+ | Series | Supported? | OS | Notes |
111
+ |---|---|---|---|
112
+ | Arc | ❓ Untested (WIP) | Win/Linux | Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB). Backends: 'xpu' or DirectML or [ipex](https://github.com/intel/intel-extension-for-pytorch) |
113
+
114
+
115
+ # FAQ
116
+ ## Why can't I just run 'pip install torch'?
117
+ `pip install torch` installs the CPU-only version of torch, so it won't utilize your GPU's capabilities.
118
+
119
+ ## Why can't I just install torch-for-ROCm directly to support AMD?
120
+ Different models of AMD cards require different LLVM targets, and sometimes different ROCm versions. And ROCm currently doesn't work on Windows, so AMD on Windows is best served (currently) with DirectML.
121
+
122
+ And plenty of AMD cards work with ROCm (even when they aren't in the official list of supported cards). Information about these cards (for e.g. the LLVM target to use) is pretty scattered.
123
+
124
+ `torchruntime` deals with this complexity for your convenience.
125
+
126
+ # Contributing
127
+ 📢 I'm looking for contributions in these specific areas:
128
+ - More testing on consumer AMD GPUs.
129
+ - More support for older AMD GPUs. Explore: Compile and host PyTorch wheels and ROCm (on GitHub) for older AMD GPUs (e.g. 580/590/Polaris) with the required patches.
130
+ - Intel GPUs.
131
+ - Testing on professional AMD GPUs (e.g. the Instinct series).
132
+ - An easy-to-run benchmark script (that people can run to check the level of compatibility on their platform).
133
+
134
+ Please message on the [Discord community](https://discord.com/invite/u9yhsFmEkB) if you have AMD or Intel GPUs, and would like to help with testing or adding support for them! Thanks!
135
+
136
+ # Credits
137
+ * Code contributors on [Easy Diffusion](https://github.com/easydiffusion/easydiffusion).
138
+ * Users on [Easy Diffusion's Discord](https://discord.com/invite/u9yhsFmEkB) who've helped with testing on various GPUs.
139
+ * [PCI Database](https://github.com/pciutils/pciids/) automatically generated from the PCI ID Database at http://pci-ids.ucw.cz
140
+
141
+ # More resources
142
+ * [AMD GPU LLVM Architectures](https://web.archive.org/web/20241228163540/https://llvm.org/docs/AMDGPUUsage.html#processors)
143
+ * [Status of ROCm support for AMD Navi 1](https://github.com/ROCm/ROCm/issues/2527)
144
+ * [Torch support for ROCm 6.2 on AMD Navi 1](https://github.com/pytorch/pytorch/issues/132570#issuecomment-2313071756)
145
+ * [ROCmLibs-for-gfx1103-AMD780M-APU](https://github.com/likelovewant/ROCmLibs-for-gfx1103-AMD780M-APU)
@@ -0,0 +1,127 @@
1
+ # torchruntime
2
+ [![Discord Server](https://img.shields.io/discord/1014774730907209781?label=Discord)](https://discord.com/invite/u9yhsFmEkB)
3
+
4
+ **torchruntime** is a lightweight package for automatically installing the appropriate variant of PyTorch on a user's computer, based on their OS, and GPU manufacturer and GPU model.
5
+
6
+ This package is used by [Easy Diffusion](https://github.com/easydiffusion/easydiffusion), but you're welcome to use it as well. It's useful for developers who make PyTorch-based apps that target users with NVIDIA, AMD and Intel graphics cards (as well as CPU-only usage), on Windows, Mac and Linux.
7
+
8
+ ### Why?
9
+ It lets you treat PyTorch as a single dependency (like it should be), and lets you assume that each user will get the most-performant variant of PyTorch suitable for their computer's OS and hardware.
10
+
11
+ It deals with the complexity of the variety of torch builds and configurations required for CUDA, AMD (ROCm, DirectML), Intel (xpu/DirectML/ipex), and CPU-only.
12
+
13
+ **Compatibility table**: [Click here](#compatibility-table) to see the supported graphics cards and operating systems.
14
+
15
+ # Installation
16
+ Supports Windows, Linux, and Mac.
17
+
18
+ `pip install torchruntime`
19
+
20
+ ## Usage
21
+ ### Step 1. Install the appropriate variant of PyTorch
22
+ *This command should be run on the user's computer, or while creating platform-specific builds:*
23
+
24
+ `python -m torchruntime install`
25
+
26
+ This will install `torch`, `torchvision`, and `torchaudio`, and will decide the variant based on the user's OS, GPU manufacturer and GPU model number. See [customizing packages](#customizing-packages) for more options.
27
+
28
+ ### Step 2. Initialize torch
29
+ This should be run inside your program, to initialize the required environment variables (if any) for the variant of torch being used.
30
+
31
+ ```py
32
+ import torchruntime
33
+
34
+ torchruntime.init_torch()
35
+ ```
36
+
37
+ ## Customizing packages
38
+ By default, `python -m torchruntime install` will install the latest available `torch`, `torchvision` and `torchaudio` suitable on the user's platform.
39
+
40
+ You can customize the packages to install by including their names:
41
+ * For e.g. to install only `torch` and `torchvision`, you can run `python -m torchruntime install torch torchvision`
42
+ * To install specific versions (in pip format), you can run `python -m torchruntime install "torch>2.0" "torchvision==0.20"`
43
+
44
+ **Note:** If you specify package versions, please keep in mind that the version may not be available to *all* the users on *all* the torch platforms. For example, a user with Python 3.8 would not be able to install torch 2.5 (or higher), because torch 2.5 dropped support for Python 3.8.
45
+
46
+ So in general, it's better to avoid specifying a version unless it really matters to you (or you know what you're doing). Instead, please allow `torchruntime` to pick the latest-possible version for the user.
47
+
48
+ # Compatibility table
49
+ The list of platforms on which `torchruntime` can install a working variant of PyTorch.
50
+
51
+ **Note:** *This list is based on user feedback (since I don't have all the cards). Please let me know if your card is supported (or not) by opening a pull request or issue or messaging on [Discord](https://discord.com/invite/u9yhsFmEkB) (with supporting logs).*
52
+
53
+ **CPU-only:**
54
+
55
+ | OS | Supported?| Notes |
56
+ |---|---|---|
57
+ | Windows | ✅ Yes | x86_64 |
58
+ | Linux | ✅ Yes | x86_64 and aarch64 |
59
+ | Mac (M1/M2/M3/M4) | ✅ Yes | arm64. `mps` backend |
60
+ | Mac (Intel) | ✅ Yes | x86_64. Stopped after `torch 2.2.2` |
61
+
62
+ **NVIDIA:**
63
+
64
+ | Series | Supported? | OS | Notes |
65
+ |---|---|---|---|
66
+ | 40xx | ✅ Yes | Win/Linux | Uses CUDA 124 |
67
+ | 30xx | ✅ Yes | Win/Linux | Uses CUDA 124 |
68
+ | 20xx | ✅ Yes | Win/Linux | Uses CUDA 124 |
69
+ | 10xx/16xx | ✅ Yes | Win/Linux | Uses CUDA 124. Full-precision required on 16xx series |
70
+
71
+ **AMD:**
72
+
73
+ | Series | Supported? | OS | Notes |
74
+ |---|---|---|---|
75
+ | 7xxx | ✅ Yes | Win/Linux | Navi3/RDNA3 (gfx110x). ROCm 6.2 on Linux. DirectML on Windows |
76
+ | 6xxx | ✅ Yes | Win/Linux | Navi2/RDNA2 (gfx103x). ROCm 6.2 on Linux. DirectML on Windows |
77
+ | 5xxx | ✅ Yes | Win/Linux | Navi1/RDNA1 (gfx101x). Full-precision required. DirectML on Windows. Linux only supports up to ROCm 5.2. Waiting for [this](https://github.com/pytorch/pytorch/issues/132570#issuecomment-2313071756) for ROCm 6.2 support. |
78
+ | 5xxx on Intel Mac | ❓ Untested (WIP) | Intel Mac | gfx101x. Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
79
+ | 4xxxG/Radeon VII | ✅ Yes | Win/Linux | Vega 20 gfx906. Need testers for Windows, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
80
+ | 2xxxG/Radeon RX Vega 56 | ❓ Untested (WIP) | N/A | Vega 10 gfx900. Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
81
+ | 5xx/Polaris | ❓ Untested (WIP) | N/A | gfx80x. Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
82
+
83
+ **Apple:**
84
+
85
+ | Series | Supported? |Notes |
86
+ |---|---|---|
87
+ | M1/M2/M3/M4 | ✅ Yes | 'mps' backend |
88
+ | AMD on Intel Mac | ❓ Untested (WIP) | Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB) |
89
+
90
+ **Intel:**
91
+
92
+ | Series | Supported? | OS | Notes |
93
+ |---|---|---|---|
94
+ | Arc | ❓ Untested (WIP) | Win/Linux | Implemented but need testers, please message on [Discord](https://discord.com/invite/u9yhsFmEkB). Backends: 'xpu' or DirectML or [ipex](https://github.com/intel/intel-extension-for-pytorch) |
95
+
96
+
97
+ # FAQ
98
+ ## Why can't I just run 'pip install torch'?
99
+ `pip install torch` installs the CPU-only version of torch, so it won't utilize your GPU's capabilities.
100
+
101
+ ## Why can't I just install torch-for-ROCm directly to support AMD?
102
+ Different models of AMD cards require different LLVM targets, and sometimes different ROCm versions. And ROCm currently doesn't work on Windows, so AMD on Windows is best served (currently) with DirectML.
103
+
104
+ And plenty of AMD cards work with ROCm (even when they aren't in the official list of supported cards). Information about these cards (for e.g. the LLVM target to use) is pretty scattered.
105
+
106
+ `torchruntime` deals with this complexity for your convenience.
107
+
108
+ # Contributing
109
+ 📢 I'm looking for contributions in these specific areas:
110
+ - More testing on consumer AMD GPUs.
111
+ - More support for older AMD GPUs. Explore: Compile and host PyTorch wheels and ROCm (on GitHub) for older AMD GPUs (e.g. 580/590/Polaris) with the required patches.
112
+ - Intel GPUs.
113
+ - Testing on professional AMD GPUs (e.g. the Instinct series).
114
+ - An easy-to-run benchmark script (that people can run to check the level of compatibility on their platform).
115
+
116
+ Please message on the [Discord community](https://discord.com/invite/u9yhsFmEkB) if you have AMD or Intel GPUs, and would like to help with testing or adding support for them! Thanks!
117
+
118
+ # Credits
119
+ * Code contributors on [Easy Diffusion](https://github.com/easydiffusion/easydiffusion).
120
+ * Users on [Easy Diffusion's Discord](https://discord.com/invite/u9yhsFmEkB) who've helped with testing on various GPUs.
121
+ * [PCI Database](https://github.com/pciutils/pciids/) automatically generated from the PCI ID Database at http://pci-ids.ucw.cz
122
+
123
+ # More resources
124
+ * [AMD GPU LLVM Architectures](https://web.archive.org/web/20241228163540/https://llvm.org/docs/AMDGPUUsage.html#processors)
125
+ * [Status of ROCm support for AMD Navi 1](https://github.com/ROCm/ROCm/issues/2527)
126
+ * [Torch support for ROCm 6.2 on AMD Navi 1](https://github.com/pytorch/pytorch/issues/132570#issuecomment-2313071756)
127
+ * [ROCmLibs-for-gfx1103-AMD780M-APU](https://github.com/likelovewant/ROCmLibs-for-gfx1103-AMD780M-APU)
@@ -0,0 +1,33 @@
1
+ [build-system]
2
+ requires = [ "setuptools",]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "torchruntime"
7
+ version = "1.0.2"
8
+ description = "Meant for app developers. A convenient way to install and configure the appropriate version of PyTorch on the user's computer, based on the OS and GPU manufacturer and model number."
9
+ readme = "README.md"
10
+ requires-python = ">=3.0"
11
+ classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: Microsoft :: Windows :: Windows 10", "Operating System :: Microsoft :: Windows :: Windows 11", "Operating System :: POSIX :: Linux", "Operating System :: MacOS",]
12
+ keywords = [ "torch", "ai", "ml", "llm", "installer", "runtime",]
13
+ dynamic = [ "dependencies",]
14
+ [[project.authors]]
15
+ name = "cmdr2"
16
+ email = "secondary.cmdr2@gmail.com"
17
+
18
+ [project.urls]
19
+ Homepage = "https://github.com/easydiffusion/torchruntime"
20
+ "Bug Tracker" = "https://github.com/easydiffusion/torchruntime/issues"
21
+
22
+ [tool.isort]
23
+ profile = "black"
24
+
25
+ [tool.black]
26
+ line-length = 120
27
+ include = "\\.pyi?$"
28
+ exclude = "/(\n \\.git\n | \\.hg\n | \\.mypy_cache\n | \\.tox\n | \\.venv\n | _build\n | buck-out\n | build\n | dist\n)/\n"
29
+
30
+ [tool.pytest.ini_options]
31
+ minversion = "6.0"
32
+ addopts = "-vs"
33
+ testpaths = [ "tests", "integration",]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,5 @@
1
+ import setuptools
2
+
3
+ setuptools.setup(
4
+ install_requires=[],
5
+ )
@@ -0,0 +1,137 @@
1
+ import os
2
+ import pytest
3
+ from torchruntime.configure import set_rocm_env_vars
4
+ from torchruntime.consts import AMD
5
+
6
+
7
+ def create_gpu_info(device_id, device_name):
8
+ return (AMD, "Advanced Micro Devices, Inc. [AMD/ATI]", device_id, device_name)
9
+
10
+
11
+ @pytest.fixture(autouse=True)
12
+ def clean_env():
13
+ # Remove relevant environment variables before each test
14
+ env_vars = [
15
+ "HSA_OVERRIDE_GFX_VERSION",
16
+ "HIP_VISIBLE_DEVICES",
17
+ "ROC_ENABLE_PRE_VEGA",
18
+ "HSA_ENABLE_SDMA",
19
+ "FORCE_FULL_PRECISION",
20
+ ]
21
+
22
+ # Store original values
23
+ original_values = {}
24
+ for var in env_vars:
25
+ if var in os.environ:
26
+ original_values[var] = os.environ[var]
27
+ del os.environ[var]
28
+
29
+ yield
30
+
31
+ # Restore original values and remove any new ones
32
+ for var in env_vars:
33
+ if var in os.environ and var not in original_values:
34
+ del os.environ[var]
35
+ elif var in original_values:
36
+ os.environ[var] = original_values[var]
37
+
38
+
39
+ def test_navi_3_settings():
40
+ gpus = [create_gpu_info("123", "Navi 31 XTX")]
41
+ set_rocm_env_vars(gpus, "rocm6.2")
42
+
43
+ assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "11.0.0"
44
+ assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
45
+ assert "ROC_ENABLE_PRE_VEGA" not in os.environ
46
+ assert "FORCE_FULL_PRECISION" not in os.environ
47
+
48
+
49
+ def test_navi_2_settings():
50
+ gpus = [create_gpu_info("123", "Navi 21 XTX")]
51
+ set_rocm_env_vars(gpus, "rocm6.2")
52
+
53
+ assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "10.3.0"
54
+ assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
55
+ assert "ROC_ENABLE_PRE_VEGA" not in os.environ
56
+ assert "FORCE_FULL_PRECISION" not in os.environ
57
+
58
+
59
+ def test_navi_1_settings():
60
+ gpus = [create_gpu_info("123", "Navi 14")]
61
+ set_rocm_env_vars(gpus, "rocm6.2")
62
+
63
+ assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "10.3.0"
64
+ assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
65
+ assert os.environ.get("FORCE_FULL_PRECISION") == "yes"
66
+ assert "ROC_ENABLE_PRE_VEGA" not in os.environ
67
+
68
+
69
+ def test_vega_2_settings():
70
+ gpus = [create_gpu_info("123", "Vega 20 Radeon VII")]
71
+ set_rocm_env_vars(gpus, "rocm6.2")
72
+
73
+ assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "9.0.6"
74
+ assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
75
+ assert "FORCE_FULL_PRECISION" not in os.environ
76
+ assert "ROC_ENABLE_PRE_VEGA" not in os.environ
77
+
78
+
79
+ def test_vega_1_settings():
80
+ gpus = [create_gpu_info("123", "Vega 10")]
81
+ set_rocm_env_vars(gpus, "rocm6.2")
82
+
83
+ assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "9.0.0"
84
+ assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
85
+ assert os.environ.get("FORCE_FULL_PRECISION") == "yes"
86
+ assert "ROC_ENABLE_PRE_VEGA" not in os.environ
87
+
88
+
89
+ def test_ellesmere_settings():
90
+ gpus = [create_gpu_info("123", "Ellesmere RX 580")]
91
+ set_rocm_env_vars(gpus, "rocm6.2")
92
+
93
+ assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "8.0.3"
94
+ assert os.environ.get("ROC_ENABLE_PRE_VEGA") == "1"
95
+ assert os.environ.get("HIP_VISIBLE_DEVICES") == "0"
96
+ assert "FORCE_FULL_PRECISION" not in os.environ
97
+
98
+
99
+ def test_unknown_gpu_settings():
100
+ gpus = [create_gpu_info("123", "Unknown GPU")]
101
+ set_rocm_env_vars(gpus, "rocm6.2")
102
+
103
+ assert "ROC_ENABLE_PRE_VEGA" not in os.environ
104
+ assert "HIP_VISIBLE_DEVICES" not in os.environ
105
+ assert "HSA_OVERRIDE_GFX_VERSION" not in os.environ
106
+ assert "FORCE_FULL_PRECISION" not in os.environ
107
+
108
+
109
+ def test_multiple_gpus_same_type():
110
+ gpus = [create_gpu_info("123", "Navi 31 XTX"), create_gpu_info("124", "Navi 31 XT")]
111
+ set_rocm_env_vars(gpus, "rocm6.2")
112
+
113
+ assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "11.0.0"
114
+ assert os.environ.get("HIP_VISIBLE_DEVICES") == "0,1"
115
+ assert "ROC_ENABLE_PRE_VEGA" not in os.environ
116
+ assert "FORCE_FULL_PRECISION" not in os.environ
117
+
118
+
119
+ def test_multiple_gpus_mixed_types():
120
+ gpus = [create_gpu_info("123", "Navi 31 XTX"), create_gpu_info("124", "Navi 21 XT")]
121
+ set_rocm_env_vars(gpus, "rocm6.2")
122
+
123
+ # Should use Navi 3 settings since at least one GPU is Navi 3
124
+ assert os.environ.get("HSA_OVERRIDE_GFX_VERSION") == "11.0.0"
125
+ assert os.environ.get("HIP_VISIBLE_DEVICES") == "0,1"
126
+ assert "ROC_ENABLE_PRE_VEGA" not in os.environ
127
+ assert "FORCE_FULL_PRECISION" not in os.environ
128
+
129
+
130
+ def test_empty_gpu_list():
131
+ gpus = []
132
+ set_rocm_env_vars(gpus, "rocm6.2")
133
+
134
+ assert "ROC_ENABLE_PRE_VEGA" not in os.environ
135
+ assert "HIP_VISIBLE_DEVICES" not in os.environ
136
+ assert "HSA_OVERRIDE_GFX_VERSION" not in os.environ
137
+ assert "FORCE_FULL_PRECISION" not in os.environ
@@ -0,0 +1,50 @@
1
+ from torchruntime.device_db import parse_windows_output, parse_linux_output, parse_macos_output
2
+
3
+
4
+ def test_parse_windows_output():
5
+ output = """
6
+ PCI\\VEN_8086&DEV_591B&SUBSYS_2212103C&REV_04
7
+ PCI\\VEN_10DE&DEV_1C82&SUBSYS_37131462&REV_A1
8
+ PCI\\VEN_10DE&DEV_2504&SUBSYS_881D1043&REV_A1\\4&22AF55FA&0&0008
9
+ """
10
+ expected = [("8086", "591b"), ("10de", "1c82"), ("10de", ("2504"))]
11
+ assert sorted(parse_windows_output(output)) == sorted(expected)
12
+
13
+
14
+ def test_parse_linux_output():
15
+ output = """
16
+ 00:02.0 VGA compatible controller: Intel Corporation HD Graphics 620 (rev 02) [8086:5916]
17
+ 01:00.0 3D controller: NVIDIA Corporation GP108M [GeForce MX150] (rev a1) [10de:1d10]
18
+ """
19
+ expected = [("8086", "5916"), ("10de", "1d10")]
20
+ assert sorted(parse_linux_output(output)) == sorted(expected)
21
+
22
+
23
+ def test_parse_macos_output():
24
+ output = """
25
+ {
26
+ "SPDisplaysDataType": [
27
+ {
28
+ "spdisplays_vendor": "Intel",
29
+ "spdisplays_device-id": "0x5916"
30
+ },
31
+ {
32
+ "spdisplays_vendor": "NVIDIA (0x10de)",
33
+ "spdisplays_device-id": "0x1c82"
34
+ }
35
+ ]
36
+ }
37
+ """
38
+ expected = [("8086", "5916"), ("10de", "1c82")]
39
+ assert sorted(parse_macos_output(output)) == sorted(expected)
40
+
41
+
42
+ def test_parse_macos_output_invalid_json():
43
+ output = "{ invalid_json: true }"
44
+ assert parse_macos_output(output) == []
45
+
46
+
47
+ def test_empty_outputs():
48
+ assert parse_windows_output("") == []
49
+ assert parse_linux_output("") == []
50
+ assert parse_macos_output("") == []
@@ -0,0 +1,50 @@
1
+ # Integration test, which connects to the database and checks for some common devices
2
+
3
+ from torchruntime.device_db import get_device_infos, DEVICE_DB_FILE
4
+
5
+
6
+ def test_db_file_exists():
7
+ import os
8
+ from torchruntime import device_db
9
+
10
+ db_path = os.path.join(os.path.dirname(device_db.__file__), DEVICE_DB_FILE)
11
+ assert os.path.exists(db_path)
12
+
13
+
14
+ def test_get_single_device():
15
+ """Test retrieving a single device."""
16
+ result = get_device_infos([("8086", "56a7")])
17
+ assert len(result) == 1
18
+ assert result[0] == ("8086", "Intel Corporation", "56a7", "DG2 [Arc Xe Graphics]")
19
+
20
+
21
+ def test_get_multiple_devices():
22
+ """Test retrieving multiple devices."""
23
+ input_ids = [("10de", "2786"), ("10de", "2504")]
24
+ result = get_device_infos(input_ids)
25
+ assert len(result) == 2
26
+ assert ("10de", "NVIDIA Corporation", "2786", "AD104 [GeForce RTX 4070]") in result
27
+ assert ("10de", "NVIDIA Corporation", "2504", "GA106 [GeForce RTX 3060 Lite Hash Rate]") in result
28
+
29
+
30
+ def test_get_amd_devices():
31
+ """Test retrieving AMD devices."""
32
+ input_ids = [("1002", "9495"), ("1002", "747e")]
33
+ result = get_device_infos(input_ids)
34
+ assert len(result) == 2
35
+ assert ("1002", "Advanced Micro Devices, Inc. [AMD/ATI]", "9495", "RV730 [Radeon HD 4600 AGP Series]") in result
36
+ assert ("1002", "Advanced Micro Devices, Inc. [AMD/ATI]", "747e", "Navi 32 [Radeon RX 7700 XT / 7800 XT]") in result
37
+
38
+
39
+ def test_get_nonexistent_device():
40
+ """Test retrieving a device that doesn't exist in the database."""
41
+ result = get_device_infos([("ffff", "ffff")])
42
+ assert len(result) == 0
43
+
44
+
45
+ def test_get_mixed_existing_and_nonexistent():
46
+ """Test retrieving a mix of existing and non-existing devices."""
47
+ input_ids = [("8086", "56a7"), ("ffff", "ffff")] # exists # doesn't exist
48
+ result = get_device_infos(input_ids)
49
+ assert len(result) == 1
50
+ assert result[0] == ("8086", "Intel Corporation", "56a7", "DG2 [Arc Xe Graphics]")