causal-conv1d 1.3.0.post1__tar.gz → 1.5.0.post5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/PKG-INFO +5 -5
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/README.md +3 -3
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/causal_conv1d/__init__.py +1 -1
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/causal_conv1d/causal_conv1d_interface.py +48 -18
- causal_conv1d-1.5.0.post5/causal_conv1d/causal_conv1d_varlen.py +86 -0
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/causal_conv1d.egg-info/PKG-INFO +5 -5
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/causal_conv1d.egg-info/SOURCES.txt +1 -0
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/setup.py +4 -5
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/AUTHORS +0 -0
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/LICENSE +0 -0
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/causal_conv1d.egg-info/dependency_links.txt +0 -0
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/causal_conv1d.egg-info/requires.txt +0 -0
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/causal_conv1d.egg-info/top_level.txt +0 -0
- {causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: causal_conv1d
|
3
|
-
Version: 1.3.0.post1
|
3
|
+
Version: 1.5.0.post5
|
4
4
|
Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
|
5
5
|
Home-page: https://github.com/Dao-AILab/causal-conv1d
|
6
6
|
Author: Tri Dao
|
@@ -8,7 +8,7 @@ Author-email: tri@tridao.me
|
|
8
8
|
Classifier: Programming Language :: Python :: 3
|
9
9
|
Classifier: License :: OSI Approved :: BSD License
|
10
10
|
Classifier: Operating System :: Unix
|
11
|
-
Requires-Python: >=3.
|
11
|
+
Requires-Python: >=3.9
|
12
12
|
Description-Content-Type: text/markdown
|
13
13
|
License-File: LICENSE
|
14
14
|
License-File: AUTHORS
|
@@ -21,11 +21,11 @@ Features:
|
|
21
21
|
|
22
22
|
## How to use
|
23
23
|
|
24
|
-
```
|
24
|
+
```python
|
25
25
|
from causal_conv1d import causal_conv1d_fn
|
26
26
|
```
|
27
27
|
|
28
|
-
```
|
28
|
+
```python
|
29
29
|
def causal_conv1d_fn(x, weight, bias=None, activation=None):
|
30
30
|
"""
|
31
31
|
x: (batch, dim, seqlen)
|
@@ -38,7 +38,7 @@ def causal_conv1d_fn(x, weight, bias=None, activation=None):
|
|
38
38
|
```
|
39
39
|
|
40
40
|
Equivalent to:
|
41
|
-
```
|
41
|
+
```python
|
42
42
|
import torch.nn.functional as F
|
43
43
|
|
44
44
|
F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
|
@@ -6,11 +6,11 @@ Features:
|
|
6
6
|
|
7
7
|
## How to use
|
8
8
|
|
9
|
-
```
|
9
|
+
```python
|
10
10
|
from causal_conv1d import causal_conv1d_fn
|
11
11
|
```
|
12
12
|
|
13
|
-
```
|
13
|
+
```python
|
14
14
|
def causal_conv1d_fn(x, weight, bias=None, activation=None):
|
15
15
|
"""
|
16
16
|
x: (batch, dim, seqlen)
|
@@ -23,7 +23,7 @@ def causal_conv1d_fn(x, weight, bias=None, activation=None):
|
|
23
23
|
```
|
24
24
|
|
25
25
|
Equivalent to:
|
26
|
-
```
|
26
|
+
```python
|
27
27
|
import torch.nn.functional as F
|
28
28
|
|
29
29
|
F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
|
{causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/causal_conv1d/causal_conv1d_interface.py
RENAMED
@@ -172,42 +172,72 @@ def causal_conv1d_ref(
|
|
172
172
|
return out if not return_final_states else (out, final_states_out)
|
173
173
|
|
174
174
|
|
175
|
-
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
|
175
|
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
|
176
176
|
"""
|
177
|
-
x: (batch, dim)
|
178
|
-
conv_state: (batch, dim, width)
|
177
|
+
x: (batch, dim) or (batch, dim, seqlen)
|
178
|
+
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
179
179
|
weight: (dim, width)
|
180
180
|
bias: (dim,)
|
181
|
-
|
182
|
-
|
181
|
+
cache_seqlens: (batch,), dtype int32.
|
182
|
+
If not None, the conv_state is treated as a circular buffer.
|
183
|
+
The conv_state will be updated by copying x to the conv_state starting at the index
|
184
|
+
@cache_seqlens % state_len.
|
185
|
+
conv_state_indices: (batch,), dtype int32
|
186
|
+
If None, the conv_state is a larger tensor along the batch dim,
|
187
|
+
and we are selecting the batch coords specified by conv_state_indices.
|
188
|
+
Useful for a continuous batching scenario.
|
189
|
+
|
190
|
+
out: (batch, dim) or (batch, dim, seqlen)
|
183
191
|
"""
|
184
192
|
if activation not in [None, "silu", "swish"]:
|
185
193
|
raise NotImplementedError("activation must be None, silu, or swish")
|
186
194
|
activation = activation in ["silu", "swish"]
|
187
|
-
|
188
|
-
|
195
|
+
unsqueeze = x.dim() == 2
|
196
|
+
if unsqueeze:
|
197
|
+
x = x.unsqueeze(-1)
|
198
|
+
out = causal_conv1d_cuda.causal_conv1d_update(
|
199
|
+
x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
|
189
200
|
)
|
201
|
+
if unsqueeze:
|
202
|
+
out = out.squeeze(-1)
|
203
|
+
return out
|
190
204
|
|
191
205
|
|
192
|
-
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
|
206
|
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
|
193
207
|
"""
|
194
|
-
x: (batch, dim)
|
195
|
-
conv_state: (batch, dim, width)
|
208
|
+
x: (batch, dim) or (batch, dim, seqlen)
|
209
|
+
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
196
210
|
weight: (dim, width)
|
197
211
|
bias: (dim,)
|
212
|
+
cache_seqlens: (batch,), dtype int32.
|
213
|
+
If not None, the conv_state is treated as a circular buffer.
|
214
|
+
The conv_state will be updated by copying x to the conv_state starting at the index
|
215
|
+
@cache_seqlens % state_len before performing the convolution.
|
198
216
|
|
199
|
-
out: (batch, dim)
|
217
|
+
out: (batch, dim) or (batch, dim, seqlen)
|
200
218
|
"""
|
201
219
|
if activation not in [None, "silu", "swish"]:
|
202
220
|
raise NotImplementedError("activation must be None, silu, or swish")
|
203
221
|
dtype_in = x.dtype
|
204
|
-
|
222
|
+
unsqueeze = x.dim() == 2
|
223
|
+
if unsqueeze:
|
224
|
+
x = x.unsqueeze(-1)
|
225
|
+
batch, dim, seqlen = x.shape
|
205
226
|
width = weight.shape[1]
|
206
|
-
|
227
|
+
state_len = conv_state.shape[-1]
|
228
|
+
assert conv_state.shape == (batch, dim, state_len)
|
207
229
|
assert weight.shape == (dim, width)
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
230
|
+
if cache_seqlens is None:
|
231
|
+
x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
|
232
|
+
conv_state.copy_(x_new[:, :, -state_len:])
|
233
|
+
else:
|
234
|
+
width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
235
|
+
width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
236
|
+
x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
|
237
|
+
copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
238
|
+
copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
239
|
+
conv_state.scatter_(2, copy_idx, x)
|
240
|
+
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
|
241
|
+
if unsqueeze:
|
242
|
+
out = out.squeeze(-1)
|
213
243
|
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
@@ -0,0 +1,86 @@
|
|
1
|
+
import torch
|
2
|
+
from torch import Tensor
|
3
|
+
|
4
|
+
import triton
|
5
|
+
import triton.language as tl
|
6
|
+
|
7
|
+
|
8
|
+
@triton.jit
|
9
|
+
def _causal_conv1d_varlen_states(
|
10
|
+
X,
|
11
|
+
CU_SEQLENS,
|
12
|
+
STATES,
|
13
|
+
state_len,
|
14
|
+
dim,
|
15
|
+
stride_x_seqlen, stride_x_dim,
|
16
|
+
stride_states_batch, stride_states_seqlen, stride_states_dim,
|
17
|
+
BLOCK_M: tl.constexpr,
|
18
|
+
BLOCK_N: tl.constexpr
|
19
|
+
):
|
20
|
+
batch_idx = tl.program_id(2)
|
21
|
+
STATES += batch_idx * stride_states_batch
|
22
|
+
end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
|
23
|
+
start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
|
24
|
+
rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
|
25
|
+
cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
|
26
|
+
x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
|
27
|
+
mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
|
28
|
+
other=0)
|
29
|
+
rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
|
30
|
+
tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
|
31
|
+
x,
|
32
|
+
mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
|
33
|
+
|
34
|
+
|
35
|
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
|
36
|
+
"""
|
37
|
+
Forward pass only, does not support backward pass.
|
38
|
+
Parameters:
|
39
|
+
x: (total_tokens, dim)
|
40
|
+
cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
|
41
|
+
state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
|
42
|
+
If some of those elements belong to a different sequence, the value of the states will be zero.
|
43
|
+
Return:
|
44
|
+
states: (batch, dim, state_len)
|
45
|
+
"""
|
46
|
+
_, dim = x.shape
|
47
|
+
batch = cu_seqlens.shape[0] - 1
|
48
|
+
cu_seqlens = cu_seqlens.contiguous()
|
49
|
+
states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
|
50
|
+
BLOCK_M = min(triton.next_power_of_2(state_len), 16)
|
51
|
+
BLOCK_N = min(triton.next_power_of_2(dim), 256)
|
52
|
+
grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
|
53
|
+
with torch.cuda.device(x.device.index):
|
54
|
+
_causal_conv1d_varlen_states[grid](
|
55
|
+
x,
|
56
|
+
cu_seqlens,
|
57
|
+
states,
|
58
|
+
state_len,
|
59
|
+
dim,
|
60
|
+
x.stride(0), x.stride(1),
|
61
|
+
states.stride(0), states.stride(2), states.stride(1),
|
62
|
+
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
|
63
|
+
)
|
64
|
+
return states
|
65
|
+
|
66
|
+
|
67
|
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
|
68
|
+
"""
|
69
|
+
Forward pass only, does not support backward pass.
|
70
|
+
Parameters:
|
71
|
+
x: (total_tokens, dim)
|
72
|
+
cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
|
73
|
+
state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
|
74
|
+
If some of those elements belong to a different sequence, the value of the states will be zero.
|
75
|
+
Return:
|
76
|
+
states: (batch, dim, state_len)
|
77
|
+
"""
|
78
|
+
_, dim = x.shape
|
79
|
+
batch = cu_seqlens.shape[0] - 1
|
80
|
+
cu_seqlens = cu_seqlens.contiguous()
|
81
|
+
states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
|
82
|
+
for i in range(batch):
|
83
|
+
end_idx = cu_seqlens[i + 1]
|
84
|
+
start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
|
85
|
+
states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
|
86
|
+
return states
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: causal-conv1d
|
3
|
-
Version: 1.3.0.post1
|
3
|
+
Version: 1.5.0.post5
|
4
4
|
Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
|
5
5
|
Home-page: https://github.com/Dao-AILab/causal-conv1d
|
6
6
|
Author: Tri Dao
|
@@ -8,7 +8,7 @@ Author-email: tri@tridao.me
|
|
8
8
|
Classifier: Programming Language :: Python :: 3
|
9
9
|
Classifier: License :: OSI Approved :: BSD License
|
10
10
|
Classifier: Operating System :: Unix
|
11
|
-
Requires-Python: >=3.
|
11
|
+
Requires-Python: >=3.9
|
12
12
|
Description-Content-Type: text/markdown
|
13
13
|
License-File: LICENSE
|
14
14
|
License-File: AUTHORS
|
@@ -21,11 +21,11 @@ Features:
|
|
21
21
|
|
22
22
|
## How to use
|
23
23
|
|
24
|
-
```
|
24
|
+
```python
|
25
25
|
from causal_conv1d import causal_conv1d_fn
|
26
26
|
```
|
27
27
|
|
28
|
-
```
|
28
|
+
```python
|
29
29
|
def causal_conv1d_fn(x, weight, bias=None, activation=None):
|
30
30
|
"""
|
31
31
|
x: (batch, dim, seqlen)
|
@@ -38,7 +38,7 @@ def causal_conv1d_fn(x, weight, bias=None, activation=None):
|
|
38
38
|
```
|
39
39
|
|
40
40
|
Equivalent to:
|
41
|
-
```
|
41
|
+
```python
|
42
42
|
import torch.nn.functional as F
|
43
43
|
|
44
44
|
F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
|
@@ -202,7 +202,6 @@ if not SKIP_CUDA_BUILD:
|
|
202
202
|
f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}",
|
203
203
|
"-U__CUDA_NO_HALF_OPERATORS__",
|
204
204
|
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
205
|
-
"-DCK_FMHA_FWD_FAST_EXP2=1",
|
206
205
|
"-fgpu-flush-denormals-to-zero",
|
207
206
|
]
|
208
207
|
+ cc_flag,
|
@@ -268,10 +267,10 @@ def get_wheel_url():
|
|
268
267
|
# We're using the CUDA version used to build torch, not the one currently installed
|
269
268
|
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
|
270
269
|
torch_cuda_version = parse(torch.version.cuda)
|
271
|
-
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.
|
270
|
+
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.4
|
272
271
|
# to save CI time. Minor versions should be compatible.
|
273
|
-
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.
|
274
|
-
cuda_version = f"{torch_cuda_version.major}
|
272
|
+
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.4")
|
273
|
+
cuda_version = f"{torch_cuda_version.major}"
|
275
274
|
|
276
275
|
gpu_compute_version = hip_version if HIP_BUILD else cuda_version
|
277
276
|
cuda_or_hip = "hip" if HIP_BUILD else "cu"
|
@@ -359,7 +358,7 @@ setup(
|
|
359
358
|
else {
|
360
359
|
"bdist_wheel": CachedWheelsCommand,
|
361
360
|
},
|
362
|
-
python_requires=">=3.
|
361
|
+
python_requires=">=3.9",
|
363
362
|
install_requires=[
|
364
363
|
"torch",
|
365
364
|
"packaging",
|
File without changes
|
File without changes
|
{causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/causal_conv1d.egg-info/dependency_links.txt
RENAMED
File without changes
|
File without changes
|
{causal_conv1d-1.3.0.post1 → causal_conv1d-1.5.0.post5}/causal_conv1d.egg-info/top_level.txt
RENAMED
File without changes
|
File without changes
|