causal-conv1d 1.2.2.post1__tar.gz → 1.3.0.post1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: causal_conv1d
3
- Version: 1.2.2.post1
3
+ Version: 1.3.0.post1
4
4
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
5
5
  Home-page: https://github.com/Dao-AILab/causal-conv1d
6
6
  Author: Tri Dao
@@ -43,3 +43,16 @@ import torch.nn.functional as F
43
43
 
44
44
  F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
45
45
  ```
46
+
47
+ ## Additional Prerequisites for AMD cards
48
+
49
+ ### Patching ROCm
50
+
51
+ If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
52
+
53
+ 1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
54
+
55
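+ 2. Apply the patch. Run the command below from the directory that contains `rocm_patch/`, replacing `/opt/rocm` with your ROCm directory if it differs. Prefix the command with `sudo` if you encounter permission errors.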
56
+ ```bash
57
+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
58
+ ```
@@ -28,3 +28,16 @@ import torch.nn.functional as F
28
28
 
29
29
  F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
30
30
  ```
31
+
32
+ ## Additional Prerequisites for AMD cards
33
+
34
+ ### Patching ROCm
35
+
36
+ If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
37
+
38
+ 1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
39
+
40
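+ 2. Apply the patch. Run the command below from the directory that contains `rocm_patch/`, replacing `/opt/rocm` with your ROCm directory if it differs. Prefix the command with `sudo` if you encounter permission errors.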
41
+ ```bash
42
+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
43
+ ```
@@ -1,3 +1,3 @@
1
- __version__ = "1.2.2.post1"
1
+ __version__ = "1.3.0.post1"
2
2
 
3
3
  from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: causal-conv1d
3
- Version: 1.2.2.post1
3
+ Version: 1.3.0.post1
4
4
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
5
5
  Home-page: https://github.com/Dao-AILab/causal-conv1d
6
6
  Author: Tri Dao
@@ -43,3 +43,16 @@ import torch.nn.functional as F
43
43
 
44
44
  F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
45
45
  ```
46
+
47
+ ## Additional Prerequisites for AMD cards
48
+
49
+ ### Patching ROCm
50
+
51
+ If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
52
+
53
+ 1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
54
+
55
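+ 2. Apply the patch. Run the command below from the directory that contains `rocm_patch/`, replacing `/opt/rocm` with your ROCm directory if it differs. Prefix the command with `sudo` if you encounter permission errors.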
56
+ ```bash
57
+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
58
+ ```
@@ -18,7 +18,7 @@ import urllib.error
18
18
  from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
19
19
 
20
20
  import torch
21
- from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
21
+ from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, HIP_HOME
22
22
 
23
23
 
24
24
  with open("README.md", "r", encoding="utf-8") as fh:
@@ -66,6 +66,45 @@ def get_cuda_bare_metal_version(cuda_dir):
66
66
  return raw_output, bare_metal_version
67
67
 
68
68
 
69
def get_hip_version(rocm_dir):
    """Run ``hipcc --version`` and extract the HIP version it reports.

    Args:
        rocm_dir: ROCm installation root (the caller passes ``ROCM_PATH``).
            When ``None``, ``hipcc`` is resolved from ``PATH``; otherwise
            ``<rocm_dir>/bin/hipcc`` is executed.

    Returns:
        A ``(line, version)`` tuple:
        - ``line`` is the raw output line containing ``"HIP version"``.
        - ``version`` is that line's last token parsed with ``parse``.
        Returns ``(None, None)`` if hipcc cannot be executed or its output has
        no ``"HIP version"`` line.
    """
    hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc")
    try:
        raw_output = subprocess.check_output(
            [hipcc_bin, "--version"], universal_newlines=True
        )
    except Exception as e:
        # Deliberately best-effort: report why the probe failed and let the
        # caller decide whether a missing HIP toolchain is fatal.
        print(
            f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}"
        )
        return None, None

    for line in raw_output.split("\n"):
        if "HIP version" in line:
            # The version token may carry a hyphenated build suffix
            # (something like "6.0.32830-d62f6a171"), which is not parsed
            # correctly as-is. Turn only the FIRST hyphen into "+" so the
            # suffix becomes a PEP 440 local version. Later hyphens are legal
            # local-segment separators, whereas a second "+" would make
            # parse() raise InvalidVersion.
            rocm_version = parse(line.split()[-1].replace("-", "+", 1))
            return line, rocm_version

    return None, None
88
+
89
+
90
def get_torch_hip_version():
    """Return the HIP version the installed torch was built with, or ``None``.

    ``torch.version.hip`` is falsy for torch builds without HIP support (it is
    ``None`` on CUDA/CPU builds), in which case ``None`` is returned.
    Otherwise the last whitespace-separated token of that string is parsed
    with ``parse``.
    """
    if not torch.version.hip:
        return None
    # Turn only the FIRST hyphen of a build suffix into "+", making it a
    # PEP 440 local version. Later hyphens are valid local-segment
    # separators, while a second "+" would make parse() raise InvalidVersion.
    return parse(torch.version.hip.split()[-1].replace("-", "+", 1))
95
+
96
+
97
def check_if_hip_home_none(global_option: str) -> None:
    """Emit a warning when ``HIP_HOME`` is unset, i.e. hipcc was not located.

    Args:
        global_option: Name of the requested build (e.g. the package name);
            it is interpolated into the warning text.

    Returns:
        None. This never raises on its own.
    """
    if HIP_HOME is None:
        # Only a warning, not an error: the user may be installing a prebuilt
        # wheel, in which case hipcc is never needed.
        warnings.warn(
            f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?"
        )
106
+
107
+
69
108
  def check_if_cuda_home_none(global_option: str) -> None:
70
109
  if CUDA_HOME is not None:
71
110
  return
@@ -85,37 +124,67 @@ def append_nvcc_threads(nvcc_extra_args):
85
124
  cmdclass = {}
86
125
  ext_modules = []
87
126
 
127
+
128
+ HIP_BUILD = bool(torch.version.hip)
129
+
88
130
  if not SKIP_CUDA_BUILD:
131
+
89
132
  print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
90
133
  TORCH_MAJOR = int(torch.__version__.split(".")[0])
91
134
  TORCH_MINOR = int(torch.__version__.split(".")[1])
92
135
 
93
- check_if_cuda_home_none("causal_conv1d")
94
- # Check, if CUDA11 is installed for compute capability 8.0
136
+
95
137
  cc_flag = []
96
- if CUDA_HOME is not None:
97
- _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
98
- if bare_metal_version < Version("11.6"):
99
- raise RuntimeError(
100
- "causal_conv1d is only supported on CUDA 11.6 and above. "
101
- "Note: make sure nvcc has a supported version by running nvcc -V."
102
- )
103
-
104
- cc_flag.append("-gencode")
105
- cc_flag.append("arch=compute_53,code=sm_53")
106
- cc_flag.append("-gencode")
107
- cc_flag.append("arch=compute_62,code=sm_62")
108
- cc_flag.append("-gencode")
109
- cc_flag.append("arch=compute_70,code=sm_70")
110
- cc_flag.append("-gencode")
111
- cc_flag.append("arch=compute_72,code=sm_72")
112
- cc_flag.append("-gencode")
113
- cc_flag.append("arch=compute_80,code=sm_80")
114
- cc_flag.append("-gencode")
115
- cc_flag.append("arch=compute_87,code=sm_87")
116
- if bare_metal_version >= Version("11.8"):
138
+
139
+ if HIP_BUILD:
140
+ check_if_hip_home_none(PACKAGE_NAME)
141
+
142
+ rocm_home = os.getenv("ROCM_PATH")
143
+ _, hip_version = get_hip_version(rocm_home)
144
+
145
+
146
+ if HIP_HOME is not None:
147
+ if hip_version < Version("6.0"):
148
+ raise RuntimeError(
149
+ f"{PACKAGE_NAME} is only supported on ROCm 6.0 and above. "
150
+ "Note: make sure HIP has a supported version by running hipcc --version."
151
+ )
152
+ if hip_version == Version("6.0"):
153
+ warnings.warn(
154
+ f"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. "
155
+ "Refer to the README.md for detailed instructions.",
156
+ UserWarning
157
+ )
158
+
159
+ cc_flag.append("-DBUILD_PYTHON_PACKAGE")
160
+
161
+ else:
162
+ check_if_cuda_home_none(PACKAGE_NAME)
163
+ # Check, if CUDA11 is installed for compute capability 8.0
164
+
165
+ if CUDA_HOME is not None:
166
+ _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
167
+ if bare_metal_version < Version("11.6"):
168
+ raise RuntimeError(
169
+ f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
170
+ "Note: make sure nvcc has a supported version by running nvcc -V."
171
+ )
172
+
173
+ cc_flag.append("-gencode")
174
+ cc_flag.append("arch=compute_53,code=sm_53")
175
+ cc_flag.append("-gencode")
176
+ cc_flag.append("arch=compute_62,code=sm_62")
177
+ cc_flag.append("-gencode")
178
+ cc_flag.append("arch=compute_70,code=sm_70")
179
+ cc_flag.append("-gencode")
180
+ cc_flag.append("arch=compute_72,code=sm_72")
117
181
  cc_flag.append("-gencode")
118
- cc_flag.append("arch=compute_90,code=sm_90")
182
+ cc_flag.append("arch=compute_80,code=sm_80")
183
+ cc_flag.append("-gencode")
184
+ cc_flag.append("arch=compute_87,code=sm_87")
185
+ if bare_metal_version >= Version("11.8"):
186
+ cc_flag.append("-gencode")
187
+ cc_flag.append("arch=compute_90,code=sm_90")
119
188
 
120
189
  # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
121
190
  # torch._C._GLIBCXX_USE_CXX11_ABI
@@ -123,6 +192,43 @@ if not SKIP_CUDA_BUILD:
123
192
  if FORCE_CXX11_ABI:
124
193
  torch._C._GLIBCXX_USE_CXX11_ABI = True
125
194
 
195
+
196
+ if HIP_BUILD:
197
+ extra_compile_args = {
198
+ "cxx": ["-O3", "-std=c++17"],
199
+ "nvcc": [
200
+ "-O3",
201
+ "-std=c++17",
202
+ f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}",
203
+ "-U__CUDA_NO_HALF_OPERATORS__",
204
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
205
+ "-DCK_FMHA_FWD_FAST_EXP2=1",
206
+ "-fgpu-flush-denormals-to-zero",
207
+ ]
208
+ + cc_flag,
209
+ }
210
+ else:
211
+ extra_compile_args = {
212
+ "cxx": ["-O3"],
213
+ "nvcc": append_nvcc_threads(
214
+ [
215
+ "-O3",
216
+ "-U__CUDA_NO_HALF_OPERATORS__",
217
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
218
+ "-U__CUDA_NO_BFLOAT16_OPERATORS__",
219
+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
220
+ "-U__CUDA_NO_BFLOAT162_OPERATORS__",
221
+ "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
222
+ "--expt-relaxed-constexpr",
223
+ "--expt-extended-lambda",
224
+ "--use_fast_math",
225
+ "--ptxas-options=-v",
226
+ "-lineinfo",
227
+ ]
228
+ + cc_flag
229
+ ),
230
+ }
231
+
126
232
  ext_modules.append(
127
233
  CUDAExtension(
128
234
  name="causal_conv1d_cuda",
@@ -132,26 +238,7 @@ if not SKIP_CUDA_BUILD:
132
238
  "csrc/causal_conv1d_bwd.cu",
133
239
  "csrc/causal_conv1d_update.cu",
134
240
  ],
135
- extra_compile_args={
136
- "cxx": ["-O3"],
137
- "nvcc": append_nvcc_threads(
138
- [
139
- "-O3",
140
- "-U__CUDA_NO_HALF_OPERATORS__",
141
- "-U__CUDA_NO_HALF_CONVERSIONS__",
142
- "-U__CUDA_NO_BFLOAT16_OPERATORS__",
143
- "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
144
- "-U__CUDA_NO_BFLOAT162_OPERATORS__",
145
- "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
146
- "--expt-relaxed-constexpr",
147
- "--expt-extended-lambda",
148
- "--use_fast_math",
149
- "--ptxas-options=-v",
150
- "-lineinfo",
151
- ]
152
- + cc_flag
153
- ),
154
- },
241
+ extra_compile_args=extra_compile_args,
155
242
  include_dirs=[Path(this_dir) / "csrc" / "causal_conv1d"],
156
243
  )
157
244
  )
@@ -169,24 +256,36 @@ def get_package_version():
169
256
 
170
257
 
171
258
  def get_wheel_url():
259
+
172
260
  # Determine the version numbers that will be used to determine the correct wheel
173
- # We're using the CUDA version used to build torch, not the one currently installed
174
- # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
175
- torch_cuda_version = parse(torch.version.cuda)
176
261
  torch_version_raw = parse(torch.__version__)
177
- # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
178
- # to save CI time. Minor versions should be compatible.
179
- torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
262
+
263
+ if HIP_BUILD:
264
+ # We're using the HIP version used to build torch, not the one currently installed
265
+ torch_hip_version = get_torch_hip_version()
266
+ hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
267
+ else:
268
+ # We're using the CUDA version used to build torch, not the one currently installed
269
+ # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
270
+ torch_cuda_version = parse(torch.version.cuda)
271
+ # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
272
+ # to save CI time. Minor versions should be compatible.
273
+ torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
274
+ cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
275
+
276
+ gpu_compute_version = hip_version if HIP_BUILD else cuda_version
277
+ cuda_or_hip = "hip" if HIP_BUILD else "cu"
278
+
180
279
  python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
181
280
  platform_name = get_platform()
182
281
  causal_conv1d_version = get_package_version()
183
- # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
184
- cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
282
+
185
283
  torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
186
284
  cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
187
285
 
188
286
  # Determine wheel URL based on CUDA version, torch version, python version and OS
189
- wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
287
+ wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
288
+
190
289
  wheel_url = BASE_WHEEL_URL.format(
191
290
  tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
192
291
  )