causal-conv1d 1.2.2.post1.tar.gz → 1.3.0.post1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/PKG-INFO +14 -1
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/README.md +13 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/causal_conv1d/__init__.py +1 -1
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/causal_conv1d.egg-info/PKG-INFO +14 -1
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/setup.py +153 -54
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/AUTHORS +0 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/LICENSE +0 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/causal_conv1d/causal_conv1d_interface.py +0 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/causal_conv1d.egg-info/SOURCES.txt +0 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/causal_conv1d.egg-info/dependency_links.txt +0 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/causal_conv1d.egg-info/requires.txt +0 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/causal_conv1d.egg-info/top_level.txt +0 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/setup.cfg +0 -0
{causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: causal_conv1d
-Version: 1.2.2.post1
+Version: 1.3.0.post1
 Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
 Home-page: https://github.com/Dao-AILab/causal-conv1d
 Author: Tri Dao
@@ -43,3 +43,16 @@ import torch.nn.functional as F
 
 F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
 ```
+
+## Additional Prerequisites for AMD cards
+
+### Patching ROCm
+
+If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
+
+1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
+
+2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
+```bash
+patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
+```
````
{causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/README.md

````diff
@@ -28,3 +28,16 @@ import torch.nn.functional as F
 
 F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
 ```
+
+## Additional Prerequisites for AMD cards
+
+### Patching ROCm
+
+If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
+
+1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
+
+2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
+```bash
+patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
+```
````
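The new README section gates the header patch on the ROCm release: only 6.0 needs it. Before editing files under `/opt/rocm`, it can help to confirm which toolchain is actually installed. Below is a minimal standalone check, assuming `hipcc` is on `PATH`; the helper name is illustrative, and the `HIP version` line it searches for is the same marker the new `get_hip_version()` in `setup.py` greps for:

```python
import re
import subprocess

def rocm_needs_bf16_patch() -> bool:
    """Return True only for ROCm 6.0.x, which needs the amd_hip_bf16.h patch."""
    out = subprocess.check_output(["hipcc", "--version"], universal_newlines=True)
    match = re.search(r"HIP version:\s*(\d+)\.(\d+)", out)
    if match is None:
        raise RuntimeError("could not find a 'HIP version' line in hipcc output")
    return (int(match.group(1)), int(match.group(2))) == (6, 0)

if __name__ == "__main__":
    print("patch required:", rocm_needs_bf16_patch())
```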
{causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/causal_conv1d.egg-info/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: causal-conv1d
-Version: 1.2.2.post1
+Version: 1.3.0.post1
 Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
 Home-page: https://github.com/Dao-AILab/causal-conv1d
 Author: Tri Dao
@@ -43,3 +43,16 @@ import torch.nn.functional as F
 
 F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
 ```
+
+## Additional Prerequisites for AMD cards
+
+### Patching ROCm
+
+If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
+
+1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
+
+2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
+```bash
+patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
+```
````
{causal_conv1d-1.2.2.post1 → causal_conv1d-1.3.0.post1}/setup.py

```diff
@@ -18,7 +18,7 @@ import urllib.error
 from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 
 import torch
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, HIP_HOME
 
 
 with open("README.md", "r", encoding="utf-8") as fh:
```
```diff
@@ -66,6 +66,45 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_version
 
 
+def get_hip_version(rocm_dir):
+
+    hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc")
+    try:
+        raw_output = subprocess.check_output(
+            [hipcc_bin, "--version"], universal_newlines=True
+        )
+    except Exception as e:
+        print(
+            f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}"
+        )
+        return None, None
+
+    for line in raw_output.split("\n"):
+        if "HIP version" in line:
+            rocm_version = parse(line.split()[-1].replace("-", "+"))  # local version is not parsed correctly
+            return line, rocm_version
+
+    return None, None
+
+
+def get_torch_hip_version():
+    if torch.version.hip:
+        return parse(torch.version.hip.split()[-1].replace("-", "+"))
+    else:
+        return None
+
+
+def check_if_hip_home_none(global_option: str) -> None:
+
+    if HIP_HOME is not None:
+        return
+    # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
+    # in that case.
+    warnings.warn(
+        f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?"
+    )
+
+
 def check_if_cuda_home_none(global_option: str) -> None:
     if CUDA_HOME is not None:
         return
```
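The `replace("-", "+")` in both new helpers, flagged by the inline comment, works around PEP 440: `hipcc` and `torch.version.hip` report versions like `6.0.32830-d62f6a171`, and the dash-separated build hash is not a valid pre/post/dev segment, whereas a `+` suffix is a legal local version label. A small self-contained demonstration with `packaging` (the sample string is illustrative; recent `packaging` releases raise `InvalidVersion` on the dash form):

```python
from packaging.version import InvalidVersion, Version, parse

raw = "6.0.32830-d62f6a171"  # illustrative hipcc / torch.version.hip output

try:
    parse(raw)
except InvalidVersion:
    print("dash form is not valid PEP 440")

hip_version = parse(raw.replace("-", "+"))  # build hash becomes a local version label
print(hip_version)                          # 6.0.32830+d62f6a171
print(hip_version >= Version("6.0"))        # True: passes the minimum-version check
print(hip_version < Version("6.1"))         # True: still in the 6.0 series
```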
```diff
@@ -85,37 +124,67 @@ def append_nvcc_threads(nvcc_extra_args):
 cmdclass = {}
 ext_modules = []
 
+
+HIP_BUILD = bool(torch.version.hip)
+
 if not SKIP_CUDA_BUILD:
+
     print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
     TORCH_MAJOR = int(torch.__version__.split(".")[0])
     TORCH_MINOR = int(torch.__version__.split(".")[1])
 
-
-    # Check, if CUDA11 is installed for compute capability 8.0
+
     cc_flag = []
… 21 removed lines (old 96-116) are collapsed in this diff view …
+
+    if HIP_BUILD:
+        check_if_hip_home_none(PACKAGE_NAME)
+
+        rocm_home = os.getenv("ROCM_PATH")
+        _, hip_version = get_hip_version(rocm_home)
+
+
+        if HIP_HOME is not None:
+            if hip_version < Version("6.0"):
+                raise RuntimeError(
+                    f"{PACKAGE_NAME} is only supported on ROCm 6.0 and above. "
+                    "Note: make sure HIP has a supported version by running hipcc --version."
+                )
+            if hip_version == Version("6.0"):
+                warnings.warn(
+                    f"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. "
+                    "Refer to the README.md for detailed instructions.",
+                    UserWarning
+                )
+
+        cc_flag.append("-DBUILD_PYTHON_PACKAGE")
+
+    else:
+        check_if_cuda_home_none(PACKAGE_NAME)
+        # Check, if CUDA11 is installed for compute capability 8.0
+
+        if CUDA_HOME is not None:
+            _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+            if bare_metal_version < Version("11.6"):
+                raise RuntimeError(
+                    f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
+                    "Note: make sure nvcc has a supported version by running nvcc -V."
+                )
+
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_53,code=sm_53")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_62,code=sm_62")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_70,code=sm_70")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_72,code=sm_72")
         cc_flag.append("-gencode")
-        cc_flag.append("arch=
+        cc_flag.append("arch=compute_80,code=sm_80")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_87,code=sm_87")
+        if bare_metal_version >= Version("11.8"):
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_90,code=sm_90")
 
     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
     # torch._C._GLIBCXX_USE_CXX11_ABI
```
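On the CUDA side, the rewritten block emits one `-gencode arch=compute_XX,code=sm_XX` pair per target architecture, and `sm_90` (Hopper) is added only when the toolkit is at least 11.8. A standalone sketch of that list construction (the function name is illustrative, not part of setup.py):

```python
from packaging.version import Version

def gencode_flags(bare_metal_version: Version) -> list[str]:
    """Build -gencode pairs the way the CUDA branch above does."""
    sms = ["53", "62", "70", "72", "80", "87"]
    if bare_metal_version >= Version("11.8"):
        sms.append("90")  # sm_90 requires CUDA 11.8 or newer
    flags: list[str] = []
    for sm in sms:
        flags += ["-gencode", f"arch=compute_{sm},code=sm_{sm}"]
    return flags

print(gencode_flags(Version("11.6")))  # stops at sm_87
print(gencode_flags(Version("12.2")))  # includes sm_90
```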
```diff
@@ -123,6 +192,43 @@ if not SKIP_CUDA_BUILD:
     if FORCE_CXX11_ABI:
         torch._C._GLIBCXX_USE_CXX11_ABI = True
+
+    if HIP_BUILD:
+        extra_compile_args = {
+            "cxx": ["-O3", "-std=c++17"],
+            "nvcc": [
+                "-O3",
+                "-std=c++17",
+                f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}",
+                "-U__CUDA_NO_HALF_OPERATORS__",
+                "-U__CUDA_NO_HALF_CONVERSIONS__",
+                "-DCK_FMHA_FWD_FAST_EXP2=1",
+                "-fgpu-flush-denormals-to-zero",
+            ]
+            + cc_flag,
+        }
+    else:
+        extra_compile_args = {
+            "cxx": ["-O3"],
+            "nvcc": append_nvcc_threads(
+                [
+                    "-O3",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+                    "--expt-relaxed-constexpr",
+                    "--expt-extended-lambda",
+                    "--use_fast_math",
+                    "--ptxas-options=-v",
+                    "-lineinfo",
+                ]
+                + cc_flag
+            ),
+        }
+
     ext_modules.append(
         CUDAExtension(
             name="causal_conv1d_cuda",
```
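For HIP builds, the NVCC-style `-gencode` pairs are replaced by a single `--offload-arch` flag whose target comes from the `HIP_ARCHITECTURES` environment variable, defaulting to `native` (compile for the GPU present at build time). A short sketch of how that flag resolves; the `gfx942` value is illustrative:

```python
import os

def offload_arch_flag() -> str:
    """Resolve --offload-arch the way the HIP branch above does."""
    return f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}"

print(offload_arch_flag())                  # --offload-arch=native
os.environ["HIP_ARCHITECTURES"] = "gfx942"  # e.g. cross-compiling for a specific target
print(offload_arch_flag())                  # --offload-arch=gfx942
```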
```diff
@@ -132,26 +238,7 @@
             "csrc/causal_conv1d_bwd.cu",
             "csrc/causal_conv1d_update.cu",
         ],
-        extra_compile_args={
-            "cxx": ["-O3"],
-            "nvcc": append_nvcc_threads(
-                [
-                    "-O3",
-                    "-U__CUDA_NO_HALF_OPERATORS__",
-                    "-U__CUDA_NO_HALF_CONVERSIONS__",
-                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-                    "--expt-relaxed-constexpr",
-                    "--expt-extended-lambda",
-                    "--use_fast_math",
-                    "--ptxas-options=-v",
-                    "-lineinfo",
-                ]
-                + cc_flag
-            ),
-        },
+        extra_compile_args=extra_compile_args,
         include_dirs=[Path(this_dir) / "csrc" / "causal_conv1d"],
     )
 )
```
```diff
@@ -169,24 +256,36 @@ def get_package_version():
 
 
 def get_wheel_url():
+
     # Determine the version numbers that will be used to determine the correct wheel
-    # We're using the CUDA version used to build torch, not the one currently installed
-    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-    torch_cuda_version = parse(torch.version.cuda)
     torch_version_raw = parse(torch.__version__)
-    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
-    # to save CI time. Minor versions should be compatible.
-    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
+
+    if HIP_BUILD:
+        # We're using the HIP version used to build torch, not the one currently installed
+        torch_hip_version = get_torch_hip_version()
+        hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
+    else:
+        # We're using the CUDA version used to build torch, not the one currently installed
+        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+        torch_cuda_version = parse(torch.version.cuda)
+        # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
+        # to save CI time. Minor versions should be compatible.
+        torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
+        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
+
+    gpu_compute_version = hip_version if HIP_BUILD else cuda_version
+    cuda_or_hip = "hip" if HIP_BUILD else "cu"
+
     python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
     platform_name = get_platform()
     causal_conv1d_version = get_package_version()
-
-    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
+
     torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
     cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
 
     # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+    wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+
     wheel_url = BASE_WHEEL_URL.format(
         tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
    )
```
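With this change, the wheel's local version tag encodes either `cu<version>` or `hip<version>` ahead of the torch version, C++ ABI flag, Python tag, and platform. A sketch of the resulting filename for a hypothetical CUDA 12.2 / torch 2.2 / CPython 3.10 Linux environment (all concrete values below are illustrative, not real release artifacts):

```python
# Compose a wheel filename the way the new get_wheel_url() does.
PACKAGE_NAME = "causal_conv1d"
causal_conv1d_version = "1.3.0.post1"
cuda_or_hip, gpu_compute_version = "cu", "122"  # a ROCm torch would yield ("hip", "60")
torch_version = "2.2"
cxx11_abi = "FALSE"
python_version = "cp310"
platform_name = "linux_x86_64"

wheel_filename = (
    f"{PACKAGE_NAME}-{causal_conv1d_version}+{cuda_or_hip}{gpu_compute_version}"
    f"torch{torch_version}cxx11abi{cxx11_abi}"
    f"-{python_version}-{python_version}-{platform_name}.whl"
)
print(wheel_filename)
# causal_conv1d-1.3.0.post1+cu122torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
```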