causal-conv1d 1.2.2.post1__tar.gz → 1.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: causal_conv1d
3
- Version: 1.2.2.post1
3
+ Version: 1.4.0
4
4
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
5
5
  Home-page: https://github.com/Dao-AILab/causal-conv1d
6
6
  Author: Tri Dao
@@ -8,7 +8,7 @@ Author-email: tri@tridao.me
8
8
  Classifier: Programming Language :: Python :: 3
9
9
  Classifier: License :: OSI Approved :: BSD License
10
10
  Classifier: Operating System :: Unix
11
- Requires-Python: >=3.7
11
+ Requires-Python: >=3.8
12
12
  Description-Content-Type: text/markdown
13
13
  License-File: LICENSE
14
14
  License-File: AUTHORS
@@ -43,3 +43,16 @@ import torch.nn.functional as F
43
43
 
44
44
  F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
45
45
  ```
46
+
47
+ ## Additional Prerequisites for AMD cards
48
+
49
+ ### Patching ROCm
50
+
51
+ If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
52
+
53
+ 1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
54
+
55
+ 2. Apply the patch, replacing `/opt/rocm/` in the command below with your ROCm directory if it differs. Run it with `sudo` if you encounter permission errors.
56
+ ```bash
57
+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
58
+ ```
@@ -28,3 +28,16 @@ import torch.nn.functional as F
28
28
 
29
29
  F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
30
30
  ```
31
+
32
+ ## Additional Prerequisites for AMD cards
33
+
34
+ ### Patching ROCm
35
+
36
+ If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
37
+
38
+ 1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
39
+
40
+ 2. Apply the patch, replacing `/opt/rocm/` in the command below with your ROCm directory if it differs. Run it with `sudo` if you encounter permission errors.
41
+ ```bash
42
+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
43
+ ```
@@ -1,3 +1,3 @@
1
- __version__ = "1.2.2.post1"
1
+ __version__ = "1.4.0"
2
2
 
3
3
  from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
@@ -172,42 +172,68 @@ def causal_conv1d_ref(
172
172
  return out if not return_final_states else (out, final_states_out)
173
173
 
174
174
 
175
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """
    Run one incremental causal conv1d step with the CUDA extension.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len.
    activation: None, "silu" or "swish".

    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    use_silu = activation in ["silu", "swish"]
    # This wrapper always hands the extension 3-D input: a single token (2-D x)
    # gets a trailing seqlen axis of size 1, which is removed again afterwards.
    single_token = x.dim() == 2
    x_3d = x.unsqueeze(-1) if single_token else x
    out = causal_conv1d_cuda.causal_conv1d_update(
        x_3d, conv_state, weight, bias, use_silu, cache_seqlens
    )
    return out.squeeze(-1) if single_token else out
190
200
 
191
201
 
192
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
    """
    Pure-PyTorch reference for causal_conv1d_update; updates conv_state in place.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.
        When seqlen > state_len, only the last state_len tokens of x end up in the buffer.
    activation: None, "silu" or "swish" (the latter two apply SiLU to the output).

    out: (batch, dim) or (batch, dim, seqlen), cast back to the dtype of x
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)
    # With fewer than width - 1 cached elements the convolution window would be
    # truncated (non-circular path) or wrap onto itself (circular path).
    assert state_len >= width - 1, "conv_state must hold at least width - 1 elements"
    if cache_seqlens is None:
        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype)  # (batch, dim, state_len + seqlen)
        conv_state.copy_(x_new[:, :, -state_len:])
    else:
        # The width - 1 ring-buffer entries preceding position cache_seqlens, read
        # before the buffer is overwritten below.
        width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
        # Write only the tokens that survive in the buffer: scattering all seqlen tokens
        # when seqlen > state_len would produce duplicate indices, for which scatter_
        # picks an arbitrary source value.
        num_keep = min(seqlen, state_len)
        copy_idx = torch.arange(seqlen - num_keep, seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
        # scatter_ requires src to have conv_state's dtype (copy_ above casts implicitly).
        conv_state.scatter_(2, copy_idx, x[:, :, seqlen - num_keep:].to(conv_state.dtype))
    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
    if unsqueeze:
        out = out.squeeze(-1)
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
@@ -0,0 +1,86 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ import triton
5
+ import triton.language as tl
6
+
7
+
8
@triton.jit
def _causal_conv1d_varlen_states(
    X,
    CU_SEQLENS,
    STATES,
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # Copy the trailing tokens of one packed (varlen) sequence into its state slot,
    # right-aligned so the last token lands at state position state_len - 1.
    # Program ids: axis 0 -> BLOCK_N-wide tile of the `dim` columns,
    #              axis 1 -> BLOCK_M-tall row tile, counted backwards from the end,
    #              axis 2 -> sequence index within the batch.
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    # Read neither before this sequence's own start nor more than state_len tokens back.
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Token rows covered by this tile, offset backwards from the sequence end.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    # Rows before start_idx and columns past dim are masked and read as 0.
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    # The matching state rows, offset backwards from state_len; negative rows lie
    # outside the buffer and are not stored.
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
33
+
34
+
35
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocated as (batch, state_len, dim) and viewed as (batch, dim, state_len),
    # so `dim` stays the contiguous axis in memory.
    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    if states.numel() == 0:
        # Nothing to fill. Launching would also fail here: next_power_of_2(0) == 0
        # yields a zero block size, and cdiv(..., 0) divides by zero.
        return states
    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
    BLOCK_N = min(triton.next_power_of_2(dim), 256)
    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            # (batch, seqlen, dim) strides of the transposed view.
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )
    return states
65
+
66
+
67
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
            Empty sequences yield an all-zero state.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for i in range(batch):
        end_idx = int(cu_seqlens[i + 1])
        start_idx = max(int(cu_seqlens[i]), end_idx - state_len)
        # Right-align the copied tokens. Index from the front: for an empty sequence
        # the negative form `-(0):` would select the whole state and fail to broadcast.
        states[i, :, state_len - (end_idx - start_idx):] = x[start_idx:end_idx].T
    return states
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: causal-conv1d
3
- Version: 1.2.2.post1
3
+ Version: 1.4.0
4
4
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
5
5
  Home-page: https://github.com/Dao-AILab/causal-conv1d
6
6
  Author: Tri Dao
@@ -8,7 +8,7 @@ Author-email: tri@tridao.me
8
8
  Classifier: Programming Language :: Python :: 3
9
9
  Classifier: License :: OSI Approved :: BSD License
10
10
  Classifier: Operating System :: Unix
11
- Requires-Python: >=3.7
11
+ Requires-Python: >=3.8
12
12
  Description-Content-Type: text/markdown
13
13
  License-File: LICENSE
14
14
  License-File: AUTHORS
@@ -43,3 +43,16 @@ import torch.nn.functional as F
43
43
 
44
44
  F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
45
45
  ```
46
+
47
+ ## Additional Prerequisites for AMD cards
48
+
49
+ ### Patching ROCm
50
+
51
+ If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
52
+
53
+ 1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
54
+
55
+ 2. Apply the patch, replacing `/opt/rocm/` in the command below with your ROCm directory if it differs. Run it with `sudo` if you encounter permission errors.
56
+ ```bash
57
+ patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
58
+ ```
@@ -4,6 +4,7 @@ README.md
4
4
  setup.py
5
5
  causal_conv1d/__init__.py
6
6
  causal_conv1d/causal_conv1d_interface.py
7
+ causal_conv1d/causal_conv1d_varlen.py
7
8
  causal_conv1d.egg-info/PKG-INFO
8
9
  causal_conv1d.egg-info/SOURCES.txt
9
10
  causal_conv1d.egg-info/dependency_links.txt
@@ -18,7 +18,7 @@ import urllib.error
18
18
  from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
19
19
 
20
20
  import torch
21
- from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
21
+ from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, HIP_HOME
22
22
 
23
23
 
24
24
  with open("README.md", "r", encoding="utf-8") as fh:
@@ -66,6 +66,45 @@ def get_cuda_bare_metal_version(cuda_dir):
66
66
  return raw_output, bare_metal_version
67
67
 
68
68
 
69
def get_hip_version(rocm_dir):
    """
    Query ``hipcc --version`` and return the HIP version it reports.

    rocm_dir: ROCm installation root (e.g. the value of ROCM_PATH), or None to run
        ``hipcc`` from PATH.

    Returns (line, version): ``line`` is the raw "HIP version" output line and
    ``version`` is a packaging Version truncated to major.minor
    (e.g. "6.0.32830-d62f6a171" -> 6.0), so that checks such as
    ``version == Version("6.0")`` behave as intended.
    Returns (None, None) if hipcc cannot be run or prints no "HIP version" line.
    """
    hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc")
    try:
        raw_output = subprocess.check_output(
            [hipcc_bin, "--version"], universal_newlines=True
        )
    except Exception as e:
        # Best effort: a missing hipcc is reported, not raised.
        print(
            f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}"
        )
        return None, None

    for line in raw_output.split("\n"):
        if "HIP version" in line:
            # The "-<build hash>" suffix is not valid PEP 440, so it is turned into a
            # local version label ("+<hash>") before parsing.
            full_version = parse(line.split()[-1].replace("-", "+"))
            # Drop the patch/build part: Version("6.0.32830+...") != Version("6.0").
            rocm_version = Version(f"{full_version.major}.{full_version.minor}")
            return line, rocm_version

    return None, None
88
+
89
+
90
def get_torch_hip_version():
    """
    Return the HIP version torch was built against, as a packaging Version.

    Returns None when ``torch.version.hip`` is empty or None (non-ROCm builds).
    """
    hip_str = torch.version.hip
    if not hip_str:
        return None
    # Take the last whitespace-separated token and turn its "-<build>" suffix into a
    # PEP 440 local label so parse() accepts it.
    version_token = hip_str.split()[-1]
    return parse(version_token.replace("-", "+"))
95
+
96
+
97
def check_if_hip_home_none(global_option: str) -> None:
    """
    Warn when a HIP build was requested but hipcc (HIP_HOME) was not found.

    This only warns instead of raising, because the user could be downloading a
    prebuilt wheel, in which case hipcc is never needed.
    """
    if HIP_HOME is None:
        warnings.warn(
            f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?"
        )
106
+
107
+
69
108
  def check_if_cuda_home_none(global_option: str) -> None:
70
109
  if CUDA_HOME is not None:
71
110
  return
@@ -85,37 +124,67 @@ def append_nvcc_threads(nvcc_extra_args):
85
124
  cmdclass = {}
86
125
  ext_modules = []
87
126
 
127
+
128
+ HIP_BUILD = bool(torch.version.hip)
129
+
88
130
  if not SKIP_CUDA_BUILD:
131
+
89
132
  print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
90
133
  TORCH_MAJOR = int(torch.__version__.split(".")[0])
91
134
  TORCH_MINOR = int(torch.__version__.split(".")[1])
92
135
 
93
- check_if_cuda_home_none("causal_conv1d")
94
- # Check, if CUDA11 is installed for compute capability 8.0
136
+
95
137
  cc_flag = []
96
- if CUDA_HOME is not None:
97
- _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
98
- if bare_metal_version < Version("11.6"):
99
- raise RuntimeError(
100
- "causal_conv1d is only supported on CUDA 11.6 and above. "
101
- "Note: make sure nvcc has a supported version by running nvcc -V."
102
- )
103
-
104
- cc_flag.append("-gencode")
105
- cc_flag.append("arch=compute_53,code=sm_53")
106
- cc_flag.append("-gencode")
107
- cc_flag.append("arch=compute_62,code=sm_62")
108
- cc_flag.append("-gencode")
109
- cc_flag.append("arch=compute_70,code=sm_70")
110
- cc_flag.append("-gencode")
111
- cc_flag.append("arch=compute_72,code=sm_72")
112
- cc_flag.append("-gencode")
113
- cc_flag.append("arch=compute_80,code=sm_80")
114
- cc_flag.append("-gencode")
115
- cc_flag.append("arch=compute_87,code=sm_87")
116
- if bare_metal_version >= Version("11.8"):
138
+
139
+ if HIP_BUILD:
140
+ check_if_hip_home_none(PACKAGE_NAME)
141
+
142
+ rocm_home = os.getenv("ROCM_PATH")
143
+ _, hip_version = get_hip_version(rocm_home)
144
+
145
+
146
+ if HIP_HOME is not None:
147
+ if hip_version < Version("6.0"):
148
+ raise RuntimeError(
149
+ f"{PACKAGE_NAME} is only supported on ROCm 6.0 and above. "
150
+ "Note: make sure HIP has a supported version by running hipcc --version."
151
+ )
152
+ if hip_version == Version("6.0"):
153
+ warnings.warn(
154
+ f"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. "
155
+ "Refer to the README.md for detailed instructions.",
156
+ UserWarning
157
+ )
158
+
159
+ cc_flag.append("-DBUILD_PYTHON_PACKAGE")
160
+
161
+ else:
162
+ check_if_cuda_home_none(PACKAGE_NAME)
163
+ # Check, if CUDA11 is installed for compute capability 8.0
164
+
165
+ if CUDA_HOME is not None:
166
+ _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
167
+ if bare_metal_version < Version("11.6"):
168
+ raise RuntimeError(
169
+ f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
170
+ "Note: make sure nvcc has a supported version by running nvcc -V."
171
+ )
172
+
173
+ cc_flag.append("-gencode")
174
+ cc_flag.append("arch=compute_53,code=sm_53")
175
+ cc_flag.append("-gencode")
176
+ cc_flag.append("arch=compute_62,code=sm_62")
177
+ cc_flag.append("-gencode")
178
+ cc_flag.append("arch=compute_70,code=sm_70")
179
+ cc_flag.append("-gencode")
180
+ cc_flag.append("arch=compute_72,code=sm_72")
117
181
  cc_flag.append("-gencode")
118
- cc_flag.append("arch=compute_90,code=sm_90")
182
+ cc_flag.append("arch=compute_80,code=sm_80")
183
+ cc_flag.append("-gencode")
184
+ cc_flag.append("arch=compute_87,code=sm_87")
185
+ if bare_metal_version >= Version("11.8"):
186
+ cc_flag.append("-gencode")
187
+ cc_flag.append("arch=compute_90,code=sm_90")
119
188
 
120
189
  # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
121
190
  # torch._C._GLIBCXX_USE_CXX11_ABI
@@ -123,6 +192,42 @@ if not SKIP_CUDA_BUILD:
123
192
  if FORCE_CXX11_ABI:
124
193
  torch._C._GLIBCXX_USE_CXX11_ABI = True
125
194
 
195
+
196
+ if HIP_BUILD:
197
+ extra_compile_args = {
198
+ "cxx": ["-O3", "-std=c++17"],
199
+ "nvcc": [
200
+ "-O3",
201
+ "-std=c++17",
202
+ f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}",
203
+ "-U__CUDA_NO_HALF_OPERATORS__",
204
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
205
+ "-fgpu-flush-denormals-to-zero",
206
+ ]
207
+ + cc_flag,
208
+ }
209
+ else:
210
+ extra_compile_args = {
211
+ "cxx": ["-O3"],
212
+ "nvcc": append_nvcc_threads(
213
+ [
214
+ "-O3",
215
+ "-U__CUDA_NO_HALF_OPERATORS__",
216
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
217
+ "-U__CUDA_NO_BFLOAT16_OPERATORS__",
218
+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
219
+ "-U__CUDA_NO_BFLOAT162_OPERATORS__",
220
+ "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
221
+ "--expt-relaxed-constexpr",
222
+ "--expt-extended-lambda",
223
+ "--use_fast_math",
224
+ "--ptxas-options=-v",
225
+ "-lineinfo",
226
+ ]
227
+ + cc_flag
228
+ ),
229
+ }
230
+
126
231
  ext_modules.append(
127
232
  CUDAExtension(
128
233
  name="causal_conv1d_cuda",
@@ -132,26 +237,7 @@ if not SKIP_CUDA_BUILD:
132
237
  "csrc/causal_conv1d_bwd.cu",
133
238
  "csrc/causal_conv1d_update.cu",
134
239
  ],
135
- extra_compile_args={
136
- "cxx": ["-O3"],
137
- "nvcc": append_nvcc_threads(
138
- [
139
- "-O3",
140
- "-U__CUDA_NO_HALF_OPERATORS__",
141
- "-U__CUDA_NO_HALF_CONVERSIONS__",
142
- "-U__CUDA_NO_BFLOAT16_OPERATORS__",
143
- "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
144
- "-U__CUDA_NO_BFLOAT162_OPERATORS__",
145
- "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
146
- "--expt-relaxed-constexpr",
147
- "--expt-extended-lambda",
148
- "--use_fast_math",
149
- "--ptxas-options=-v",
150
- "-lineinfo",
151
- ]
152
- + cc_flag
153
- ),
154
- },
240
+ extra_compile_args=extra_compile_args,
155
241
  include_dirs=[Path(this_dir) / "csrc" / "causal_conv1d"],
156
242
  )
157
243
  )
@@ -169,24 +255,36 @@ def get_package_version():
169
255
 
170
256
 
171
257
  def get_wheel_url():
258
+
172
259
  # Determine the version numbers that will be used to determine the correct wheel
173
- # We're using the CUDA version used to build torch, not the one currently installed
174
- # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
175
- torch_cuda_version = parse(torch.version.cuda)
176
260
  torch_version_raw = parse(torch.__version__)
177
- # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
178
- # to save CI time. Minor versions should be compatible.
179
- torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
261
+
262
+ if HIP_BUILD:
263
+ # We're using the HIP version used to build torch, not the one currently installed
264
+ torch_hip_version = get_torch_hip_version()
265
+ hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
266
+ else:
267
+ # We're using the CUDA version used to build torch, not the one currently installed
268
+ # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
269
+ torch_cuda_version = parse(torch.version.cuda)
270
+ # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
271
+ # to save CI time. Minor versions should be compatible.
272
+ torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
273
+ cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
274
+
275
+ gpu_compute_version = hip_version if HIP_BUILD else cuda_version
276
+ cuda_or_hip = "hip" if HIP_BUILD else "cu"
277
+
180
278
  python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
181
279
  platform_name = get_platform()
182
280
  causal_conv1d_version = get_package_version()
183
- # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
184
- cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
281
+
185
282
  torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
186
283
  cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
187
284
 
188
285
  # Determine wheel URL based on CUDA version, torch version, python version and OS
189
- wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
286
+ wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
287
+
190
288
  wheel_url = BASE_WHEEL_URL.format(
191
289
  tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
192
290
  )
@@ -260,7 +358,7 @@ setup(
260
358
  else {
261
359
  "bdist_wheel": CachedWheelsCommand,
262
360
  },
263
- python_requires=">=3.7",
361
+ python_requires=">=3.8",
264
362
  install_requires=[
265
363
  "torch",
266
364
  "packaging",