causal-conv1d 1.2.2.post1.tar.gz → 1.4.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/PKG-INFO +15 -2
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/README.md +13 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/causal_conv1d/__init__.py +1 -1
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/causal_conv1d/causal_conv1d_interface.py +43 -17
- causal_conv1d-1.4.0/causal_conv1d/causal_conv1d_varlen.py +86 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/causal_conv1d.egg-info/PKG-INFO +15 -2
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/causal_conv1d.egg-info/SOURCES.txt +1 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/setup.py +153 -55
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/AUTHORS +0 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/LICENSE +0 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/causal_conv1d.egg-info/dependency_links.txt +0 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/causal_conv1d.egg-info/requires.txt +0 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/causal_conv1d.egg-info/top_level.txt +0 -0
- {causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/setup.cfg +0 -0
{causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: causal_conv1d
-Version: 1.2.2.post1
+Version: 1.4.0
 Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
 Home-page: https://github.com/Dao-AILab/causal-conv1d
 Author: Tri Dao
@@ -8,7 +8,7 @@ Author-email: tri@tridao.me
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: BSD License
 Classifier: Operating System :: Unix
-Requires-Python: >=3.
+Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
 License-File: AUTHORS
@@ -43,3 +43,16 @@ import torch.nn.functional as F
 
 F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
 ```
+
+## Additional Prerequisites for AMD cards
+
+### Patching ROCm
+
+If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
+
+1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
+
+2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
+```bash
+patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
+```
{causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/README.md
@@ -28,3 +28,16 @@ import torch.nn.functional as F
 
 F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
 ```
+
+## Additional Prerequisites for AMD cards
+
+### Patching ROCm
+
+If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
+
+1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
+
+2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
+```bash
+patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
+```
{causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/causal_conv1d/causal_conv1d_interface.py
@@ -172,42 +172,68 @@ def causal_conv1d_ref(
     return out if not return_final_states else (out, final_states_out)
 
 
-def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
     """
-    x: (batch, dim)
-    conv_state: (batch, dim, width
+    x: (batch, dim) or (batch, dim, seqlen)
+    conv_state: (batch, dim, state_len), where state_len >= width - 1
     weight: (dim, width)
     bias: (dim,)
+    cache_seqlens: (batch,), dtype int32.
+        If not None, the conv_state is treated as a circular buffer.
+        The conv_state will be updated by copying x to the conv_state starting at the index
+        @cache_seqlens % state_len.
 
-    out: (batch, dim)
+    out: (batch, dim) or (batch, dim, seqlen)
     """
     if activation not in [None, "silu", "swish"]:
         raise NotImplementedError("activation must be None, silu, or swish")
     activation = activation in ["silu", "swish"]
-
-
+    unsqueeze = x.dim() == 2
+    if unsqueeze:
+        x = x.unsqueeze(-1)
+    out = causal_conv1d_cuda.causal_conv1d_update(
+        x, conv_state, weight, bias, activation, cache_seqlens
     )
+    if unsqueeze:
+        out = out.squeeze(-1)
+    return out
 
 
-def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
     """
-    x: (batch, dim)
-    conv_state: (batch, dim, width
+    x: (batch, dim) or (batch, dim, seqlen)
+    conv_state: (batch, dim, state_len), where state_len >= width - 1
     weight: (dim, width)
     bias: (dim,)
+    cache_seqlens: (batch,), dtype int32.
+        If not None, the conv_state is treated as a circular buffer.
+        The conv_state will be updated by copying x to the conv_state starting at the index
+        @cache_seqlens % state_len before performing the convolution.
 
-    out: (batch, dim)
+    out: (batch, dim) or (batch, dim, seqlen)
     """
     if activation not in [None, "silu", "swish"]:
         raise NotImplementedError("activation must be None, silu, or swish")
     dtype_in = x.dtype
-
+    unsqueeze = x.dim() == 2
+    if unsqueeze:
+        x = x.unsqueeze(-1)
+    batch, dim, seqlen = x.shape
     width = weight.shape[1]
-
+    state_len = conv_state.shape[-1]
+    assert conv_state.shape == (batch, dim, state_len)
     assert weight.shape == (dim, width)
-
-
-
-
-
+    if cache_seqlens is None:
+        x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
+        conv_state.copy_(x_new[:, :, -state_len:])
+    else:
+        width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+        width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+        x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
+        copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
+        copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
+        conv_state.scatter_(2, copy_idx, x)
+    out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
+    if unsqueeze:
+        out = out.squeeze(-1)
     return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
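The update above lets `causal_conv1d_update` take `(batch, dim, seqlen)` inputs and, when `cache_seqlens` is passed, treat `conv_state` as a circular buffer written at offset `cache_seqlens % state_len`. A minimal sketch of the new path using the pure-PyTorch reference `causal_conv1d_update_ref`; the shapes and values below are illustrative, not taken from the package's tests:

```python
import torch

from causal_conv1d.causal_conv1d_interface import causal_conv1d_update_ref

# Illustrative sizes (not from the package): batch=2, dim=4, width=3,
# circular buffer of state_len=8 (>= width - 1), 5 new tokens per sequence.
batch, dim, width, state_len, seqlen = 2, 4, 3, 8, 5
x = torch.randn(batch, dim, seqlen)
weight = torch.randn(dim, width)
bias = torch.randn(dim)
conv_state = torch.zeros(batch, dim, state_len)           # circular buffer, updated in place
cache_seqlens = torch.tensor([0, 3], dtype=torch.int32)   # per-sequence write offset into the buffer

out = causal_conv1d_update_ref(x, conv_state, weight, bias, activation="silu",
                               cache_seqlens=cache_seqlens)
print(out.shape)  # torch.Size([2, 4, 5]); x was written into conv_state at cache_seqlens % state_len
```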
causal_conv1d-1.4.0/causal_conv1d/causal_conv1d_varlen.py
@@ -0,0 +1,86 @@
+import torch
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _causal_conv1d_varlen_states(
+    X,
+    CU_SEQLENS,
+    STATES,
+    state_len,
+    dim,
+    stride_x_seqlen, stride_x_dim,
+    stride_states_batch, stride_states_seqlen, stride_states_dim,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr
+):
+    batch_idx = tl.program_id(2)
+    STATES += batch_idx * stride_states_batch
+    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
+    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
+    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
+    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
+                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
+                other=0)
+    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
+    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
+             x,
+             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
+
+
+def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+    """
+    Forward pass only, does not support backward pass.
+    Parameters:
+        x: (total_tokens, dim)
+        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+            If some of those elements belong to a different sequence, the value of the states will be zero.
+    Return:
+        states: (batch, dim, state_len)
+    """
+    _, dim = x.shape
+    batch = cu_seqlens.shape[0] - 1
+    cu_seqlens = cu_seqlens.contiguous()
+    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
+    BLOCK_N = min(triton.next_power_of_2(dim), 256)
+    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
+    with torch.cuda.device(x.device.index):
+        _causal_conv1d_varlen_states[grid](
+            x,
+            cu_seqlens,
+            states,
+            state_len,
+            dim,
+            x.stride(0), x.stride(1),
+            states.stride(0), states.stride(2), states.stride(1),
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+        )
+    return states
+
+
+def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
+    """
+    Forward pass only, does not support backward pass.
+    Parameters:
+        x: (total_tokens, dim)
+        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
+        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
+            If some of those elements belong to a different sequence, the value of the states will be zero.
+    Return:
+        states: (batch, dim, state_len)
+    """
+    _, dim = x.shape
+    batch = cu_seqlens.shape[0] - 1
+    cu_seqlens = cu_seqlens.contiguous()
+    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
+    for i in range(batch):
+        end_idx = cu_seqlens[i + 1]
+        start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
+        states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
+    return states
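The new `causal_conv1d_varlen_states` helper (and its reference version) packs the trailing `state_len` tokens of each sequence in a varlen batch into a `(batch, dim, state_len)` state tensor, zero-filling sequences shorter than `state_len`. A small sketch using the reference function on CPU, assuming `triton` is installed (the module imports it at the top level); the sequence lengths below are made up for illustration, and the Triton kernel itself requires a CUDA device:

```python
import torch

from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states_ref

dim, state_len = 4, 3
seqlens = [5, 2, 7]                                          # three sequences packed into one tensor
cu_seqlens = torch.tensor([0, 5, 7, 14], dtype=torch.int32)  # cumulative lengths, starting from 0
x = torch.randn(sum(seqlens), dim)                           # (total_tokens, dim)

states = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len)
print(states.shape)  # torch.Size([3, 4, 3])
# The 2-token sequence fills only its last two slots; the leading slot stays zero.
```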
{causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/causal_conv1d.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: causal-conv1d
-Version: 1.2.2.post1
+Version: 1.4.0
 Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
 Home-page: https://github.com/Dao-AILab/causal-conv1d
 Author: Tri Dao
@@ -8,7 +8,7 @@ Author-email: tri@tridao.me
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: BSD License
 Classifier: Operating System :: Unix
-Requires-Python: >=3.
+Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
 License-File: AUTHORS
@@ -43,3 +43,16 @@ import torch.nn.functional as F
 
 F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
 ```
+
+## Additional Prerequisites for AMD cards
+
+### Patching ROCm
+
+If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.
+
+1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
+
+2. Apply the Patch. Run with `sudo` in case you encounter permission issues.
+```bash
+patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
+```
{causal_conv1d-1.2.2.post1 → causal_conv1d-1.4.0}/setup.py
@@ -18,7 +18,7 @@ import urllib.error
 from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 
 import torch
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, HIP_HOME
 
 
 with open("README.md", "r", encoding="utf-8") as fh:
@@ -66,6 +66,45 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_version
 
 
+def get_hip_version(rocm_dir):
+
+    hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc")
+    try:
+        raw_output = subprocess.check_output(
+            [hipcc_bin, "--version"], universal_newlines=True
+        )
+    except Exception as e:
+        print(
+            f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}"
+        )
+        return None, None
+
+    for line in raw_output.split("\n"):
+        if "HIP version" in line:
+            rocm_version = parse(line.split()[-1].replace("-", "+")) # local version is not parsed correctly
+            return line, rocm_version
+
+    return None, None
+
+
+def get_torch_hip_version():
+    if torch.version.hip:
+        return parse(torch.version.hip.split()[-1].replace("-", "+"))
+    else:
+        return None
+
+
+def check_if_hip_home_none(global_option: str) -> None:
+
+    if HIP_HOME is not None:
+        return
+    # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
+    # in that case.
+    warnings.warn(
+        f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?"
+    )
+
+
 def check_if_cuda_home_none(global_option: str) -> None:
     if CUDA_HOME is not None:
         return
@@ -85,37 +124,67 @@ def append_nvcc_threads(nvcc_extra_args):
 cmdclass = {}
 ext_modules = []
 
+
+HIP_BUILD = bool(torch.version.hip)
+
 if not SKIP_CUDA_BUILD:
+
     print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
     TORCH_MAJOR = int(torch.__version__.split(".")[0])
     TORCH_MINOR = int(torch.__version__.split(".")[1])
 
-
-    # Check, if CUDA11 is installed for compute capability 8.0
+
     cc_flag = []
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    if HIP_BUILD:
+        check_if_hip_home_none(PACKAGE_NAME)
+
+        rocm_home = os.getenv("ROCM_PATH")
+        _, hip_version = get_hip_version(rocm_home)
+
+
+        if HIP_HOME is not None:
+            if hip_version < Version("6.0"):
+                raise RuntimeError(
+                    f"{PACKAGE_NAME} is only supported on ROCm 6.0 and above. "
+                    "Note: make sure HIP has a supported version by running hipcc --version."
+                )
+            if hip_version == Version("6.0"):
+                warnings.warn(
+                    f"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. "
+                    "Refer to the README.md for detailed instructions.",
+                    UserWarning
+                )
+
+        cc_flag.append("-DBUILD_PYTHON_PACKAGE")
+
+    else:
+        check_if_cuda_home_none(PACKAGE_NAME)
+        # Check, if CUDA11 is installed for compute capability 8.0
+
+        if CUDA_HOME is not None:
+            _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+            if bare_metal_version < Version("11.6"):
+                raise RuntimeError(
+                    f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
+                    "Note: make sure nvcc has a supported version by running nvcc -V."
+                )
+
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_53,code=sm_53")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_62,code=sm_62")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_70,code=sm_70")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_72,code=sm_72")
         cc_flag.append("-gencode")
-        cc_flag.append("arch=
+        cc_flag.append("arch=compute_80,code=sm_80")
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_87,code=sm_87")
+        if bare_metal_version >= Version("11.8"):
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_90,code=sm_90")
 
     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
     # torch._C._GLIBCXX_USE_CXX11_ABI
@@ -123,6 +192,42 @@ if not SKIP_CUDA_BUILD:
     if FORCE_CXX11_ABI:
         torch._C._GLIBCXX_USE_CXX11_ABI = True
 
+
+    if HIP_BUILD:
+        extra_compile_args = {
+            "cxx": ["-O3", "-std=c++17"],
+            "nvcc": [
+                "-O3",
+                "-std=c++17",
+                f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}",
+                "-U__CUDA_NO_HALF_OPERATORS__",
+                "-U__CUDA_NO_HALF_CONVERSIONS__",
+                "-fgpu-flush-denormals-to-zero",
+            ]
+            + cc_flag,
+        }
+    else:
+        extra_compile_args = {
+            "cxx": ["-O3"],
+            "nvcc": append_nvcc_threads(
+                [
+                    "-O3",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+                    "--expt-relaxed-constexpr",
+                    "--expt-extended-lambda",
+                    "--use_fast_math",
+                    "--ptxas-options=-v",
+                    "-lineinfo",
+                ]
+                + cc_flag
+            ),
+        }
+
     ext_modules.append(
         CUDAExtension(
             name="causal_conv1d_cuda",
@@ -132,26 +237,7 @@ if not SKIP_CUDA_BUILD:
                 "csrc/causal_conv1d_bwd.cu",
                 "csrc/causal_conv1d_update.cu",
             ],
-            extra_compile_args=
-                "cxx": ["-O3"],
-                "nvcc": append_nvcc_threads(
-                    [
-                        "-O3",
-                        "-U__CUDA_NO_HALF_OPERATORS__",
-                        "-U__CUDA_NO_HALF_CONVERSIONS__",
-                        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-                        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-                        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-                        "--expt-relaxed-constexpr",
-                        "--expt-extended-lambda",
-                        "--use_fast_math",
-                        "--ptxas-options=-v",
-                        "-lineinfo",
-                    ]
-                    + cc_flag
-                ),
-            },
+            extra_compile_args=extra_compile_args,
             include_dirs=[Path(this_dir) / "csrc" / "causal_conv1d"],
         )
     )
@@ -169,24 +255,36 @@ def get_package_version():
 
 
 def get_wheel_url():
+
     # Determine the version numbers that will be used to determine the correct wheel
-    # We're using the CUDA version used to build torch, not the one currently installed
-    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-    torch_cuda_version = parse(torch.version.cuda)
     torch_version_raw = parse(torch.__version__)
-
-
-
+
+    if HIP_BUILD:
+        # We're using the HIP version used to build torch, not the one currently installed
+        torch_hip_version = get_torch_hip_version()
+        hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
+    else:
+        # We're using the CUDA version used to build torch, not the one currently installed
+        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+        torch_cuda_version = parse(torch.version.cuda)
+        # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
+        # to save CI time. Minor versions should be compatible.
+        torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
+        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
+
+    gpu_compute_version = hip_version if HIP_BUILD else cuda_version
+    cuda_or_hip = "hip" if HIP_BUILD else "cu"
+
     python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
     platform_name = get_platform()
     causal_conv1d_version = get_package_version()
-
-    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
+
     torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
     cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
 
     # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+
+    wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+
     wheel_url = BASE_WHEEL_URL.format(
         tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename
    )
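For reference, the reworked `wheel_filename` f-string above resolves to names like the one printed below. The concrete values are assumptions for one plausible environment (CUDA torch build, torch 2.3, CPython 3.10, Linux x86_64), not a list of published wheels:

```python
# Mirrors the f-string in get_wheel_url() with hand-filled values (assumed, for illustration only).
PACKAGE_NAME = "causal_conv1d"
causal_conv1d_version = "1.4.0"
cuda_or_hip, gpu_compute_version = "cu", "122"   # a ROCm torch build would give "hip" plus the HIP major/minor
torch_version, cxx11_abi = "2.3", "FALSE"
python_version, platform_name = "cp310", "linux_x86_64"

wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
print(wheel_filename)
# causal_conv1d-1.4.0+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
```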
@@ -260,7 +358,7 @@ setup(
     else {
         "bdist_wheel": CachedWheelsCommand,
     },
-    python_requires=">=3.
+    python_requires=">=3.8",
    install_requires=[
        "torch",
        "packaging",