causal-conv1d 1.3.0.post1__tar.gz → 1.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: causal_conv1d
3
- Version: 1.3.0.post1
3
+ Version: 1.4.0
4
4
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
5
5
  Home-page: https://github.com/Dao-AILab/causal-conv1d
6
6
  Author: Tri Dao
@@ -8,7 +8,7 @@ Author-email: tri@tridao.me
8
8
  Classifier: Programming Language :: Python :: 3
9
9
  Classifier: License :: OSI Approved :: BSD License
10
10
  Classifier: Operating System :: Unix
11
- Requires-Python: >=3.7
11
+ Requires-Python: >=3.8
12
12
  Description-Content-Type: text/markdown
13
13
  License-File: LICENSE
14
14
  License-File: AUTHORS
@@ -1,3 +1,3 @@
1
- __version__ = "1.3.0.post1"
1
+ __version__ = "1.4.0"
2
2
 
3
3
  from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
@@ -172,42 +172,68 @@ def causal_conv1d_ref(
172
172
  return out if not return_final_states else (out, final_states_out)
173
173
 
174
174
 
175
- def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
175
+ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
176
176
  """
177
- x: (batch, dim)
178
- conv_state: (batch, dim, width)
177
+ x: (batch, dim) or (batch, dim, seqlen)
178
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
179
179
  weight: (dim, width)
180
180
  bias: (dim,)
181
+ cache_seqlens: (batch,), dtype int32.
182
+ If not None, the conv_state is treated as a circular buffer.
183
+ The conv_state will be updated by copying x to the conv_state starting at the index
184
+ @cache_seqlens % state_len.
181
185
 
182
- out: (batch, dim)
186
+ out: (batch, dim) or (batch, dim, seqlen)
183
187
  """
184
188
  if activation not in [None, "silu", "swish"]:
185
189
  raise NotImplementedError("activation must be None, silu, or swish")
186
190
  activation = activation in ["silu", "swish"]
187
- return causal_conv1d_cuda.causal_conv1d_update(
188
- x, conv_state, weight, bias, activation
191
+ unsqueeze = x.dim() == 2
192
+ if unsqueeze:
193
+ x = x.unsqueeze(-1)
194
+ out = causal_conv1d_cuda.causal_conv1d_update(
195
+ x, conv_state, weight, bias, activation, cache_seqlens
189
196
  )
197
+ if unsqueeze:
198
+ out = out.squeeze(-1)
199
+ return out
190
200
 
191
201
 
192
- def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
202
+ def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
193
203
  """
194
- x: (batch, dim)
195
- conv_state: (batch, dim, width)
204
+ x: (batch, dim) or (batch, dim, seqlen)
205
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
196
206
  weight: (dim, width)
197
207
  bias: (dim,)
208
+ cache_seqlens: (batch,), dtype int32.
209
+ If not None, the conv_state is treated as a circular buffer.
210
+ The conv_state will be updated by copying x to the conv_state starting at the index
211
+ @cache_seqlens % state_len before performing the convolution.
198
212
 
199
- out: (batch, dim)
213
+ out: (batch, dim) or (batch, dim, seqlen)
200
214
  """
201
215
  if activation not in [None, "silu", "swish"]:
202
216
  raise NotImplementedError("activation must be None, silu, or swish")
203
217
  dtype_in = x.dtype
204
- batch, dim = x.shape
218
+ unsqueeze = x.dim() == 2
219
+ if unsqueeze:
220
+ x = x.unsqueeze(-1)
221
+ batch, dim, seqlen = x.shape
205
222
  width = weight.shape[1]
206
- assert conv_state.shape == (batch, dim, width)
223
+ state_len = conv_state.shape[-1]
224
+ assert conv_state.shape == (batch, dim, state_len)
207
225
  assert weight.shape == (dim, width)
208
- conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
209
- conv_state[:, :, -1] = x
210
- out = torch.sum(conv_state * weight, dim=-1) # (B D)
211
- if bias is not None:
212
- out += bias
226
+ if cache_seqlens is None:
227
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
228
+ conv_state.copy_(x_new[:, :, -state_len:])
229
+ else:
230
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
231
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
232
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
233
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
234
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
235
+ conv_state.scatter_(2, copy_idx, x)
236
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
237
+ if unsqueeze:
238
+ out = out.squeeze(-1)
213
239
  return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
@@ -0,0 +1,86 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ import triton
5
+ import triton.language as tl
6
+
7
+
8
+ @triton.jit
9
+ def _causal_conv1d_varlen_states(
10
+ X,
11
+ CU_SEQLENS,
12
+ STATES,
13
+ state_len,
14
+ dim,
15
+ stride_x_seqlen, stride_x_dim,
16
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
17
+ BLOCK_M: tl.constexpr,
18
+ BLOCK_N: tl.constexpr
19
+ ):
20
+ batch_idx = tl.program_id(2)
21
+ STATES += batch_idx * stride_states_batch
22
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
23
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
24
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
25
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
26
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
27
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
28
+ other=0)
29
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
30
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
31
+ x,
32
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
33
+
34
+
35
+ def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
36
+ """
37
+ Forward pass only, does not support backward pass.
38
+ Parameters:
39
+ x: (total_tokens, dim)
40
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
41
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
42
+ If some of those elements belong to a different sequence, the value of the states will be zero.
43
+ Return:
44
+ states: (batch, dim, state_len)
45
+ """
46
+ _, dim = x.shape
47
+ batch = cu_seqlens.shape[0] - 1
48
+ cu_seqlens = cu_seqlens.contiguous()
49
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
50
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
51
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
52
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
53
+ with torch.cuda.device(x.device.index):
54
+ _causal_conv1d_varlen_states[grid](
55
+ x,
56
+ cu_seqlens,
57
+ states,
58
+ state_len,
59
+ dim,
60
+ x.stride(0), x.stride(1),
61
+ states.stride(0), states.stride(2), states.stride(1),
62
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
63
+ )
64
+ return states
65
+
66
+
67
+ def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
68
+ """
69
+ Forward pass only, does not support backward pass.
70
+ Parameters:
71
+ x: (total_tokens, dim)
72
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
73
+ state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
74
+ If some of those elements belong to a different sequence, the value of the states will be zero.
75
+ Return:
76
+ states: (batch, dim, state_len)
77
+ """
78
+ _, dim = x.shape
79
+ batch = cu_seqlens.shape[0] - 1
80
+ cu_seqlens = cu_seqlens.contiguous()
81
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
82
+ for i in range(batch):
83
+ end_idx = cu_seqlens[i + 1]
84
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
85
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
86
+ return states
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: causal-conv1d
3
- Version: 1.3.0.post1
3
+ Version: 1.4.0
4
4
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
5
5
  Home-page: https://github.com/Dao-AILab/causal-conv1d
6
6
  Author: Tri Dao
@@ -8,7 +8,7 @@ Author-email: tri@tridao.me
8
8
  Classifier: Programming Language :: Python :: 3
9
9
  Classifier: License :: OSI Approved :: BSD License
10
10
  Classifier: Operating System :: Unix
11
- Requires-Python: >=3.7
11
+ Requires-Python: >=3.8
12
12
  Description-Content-Type: text/markdown
13
13
  License-File: LICENSE
14
14
  License-File: AUTHORS
@@ -4,6 +4,7 @@ README.md
4
4
  setup.py
5
5
  causal_conv1d/__init__.py
6
6
  causal_conv1d/causal_conv1d_interface.py
7
+ causal_conv1d/causal_conv1d_varlen.py
7
8
  causal_conv1d.egg-info/PKG-INFO
8
9
  causal_conv1d.egg-info/SOURCES.txt
9
10
  causal_conv1d.egg-info/dependency_links.txt
@@ -202,7 +202,6 @@ if not SKIP_CUDA_BUILD:
202
202
  f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}",
203
203
  "-U__CUDA_NO_HALF_OPERATORS__",
204
204
  "-U__CUDA_NO_HALF_CONVERSIONS__",
205
- "-DCK_FMHA_FWD_FAST_EXP2=1",
206
205
  "-fgpu-flush-denormals-to-zero",
207
206
  ]
208
207
  + cc_flag,
@@ -359,7 +358,7 @@ setup(
359
358
  else {
360
359
  "bdist_wheel": CachedWheelsCommand,
361
360
  },
362
- python_requires=">=3.7",
361
+ python_requires=">=3.8",
363
362
  install_requires=[
364
363
  "torch",
365
364
  "packaging",