causal-conv1d 1.4.0.tar.gz → 1.5.0.post5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: causal_conv1d
- Version: 1.4.0
+ Version: 1.5.0.post5
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
  Home-page: https://github.com/Dao-AILab/causal-conv1d
  Author: Tri Dao
@@ -8,7 +8,7 @@ Author-email: tri@tridao.me
  Classifier: Programming Language :: Python :: 3
  Classifier: License :: OSI Approved :: BSD License
  Classifier: Operating System :: Unix
- Requires-Python: >=3.8
+ Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
  License-File: AUTHORS
@@ -21,11 +21,11 @@ Features:

  ## How to use

- ```
+ ```python
  from causal_conv1d import causal_conv1d_fn
  ```

- ```
+ ```python
  def causal_conv1d_fn(x, weight, bias=None, activation=None):
  """
  x: (batch, dim, seqlen)
@@ -38,7 +38,7 @@ def causal_conv1d_fn(x, weight, bias=None, activation=None):
  ```

  Equivalent to:
- ```
+ ```python
  import torch.nn.functional as F

  F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
@@ -6,11 +6,11 @@ Features:

  ## How to use

- ```
+ ```python
  from causal_conv1d import causal_conv1d_fn
  ```

- ```
+ ```python
  def causal_conv1d_fn(x, weight, bias=None, activation=None):
  """
  x: (batch, dim, seqlen)
@@ -23,7 +23,7 @@ def causal_conv1d_fn(x, weight, bias=None, activation=None):
  ```

  Equivalent to:
- ```
+ ```python
  import torch.nn.functional as F

  F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
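The README snippet above states that `causal_conv1d_fn` is equivalent to a left-padded depthwise `F.conv1d` truncated back to `seqlen`. Below is a minimal sketch, not part of the package, that checks this claim: it assumes a CUDA build of causal-conv1d is installed, uses illustrative shapes, and applies `"silu"` on the reference path to match the fused activation.

```python
# Hedged check of the equivalence described in the README (illustrative shapes).
import torch
import torch.nn.functional as F
from causal_conv1d import causal_conv1d_fn

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)

# Fused CUDA path with the SiLU activation applied inside the kernel.
out = causal_conv1d_fn(x, weight, bias, activation="silu")

# Reference path from the README: depthwise conv with causal (left) padding,
# truncated back to seqlen, then the same activation applied separately.
ref = F.silu(
    F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
)

print(torch.allclose(out, ref, atol=1e-2, rtol=1e-2))  # expected: True within fp16 tolerance
```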
@@ -1,3 +1,3 @@
- __version__ = "1.4.0"
+ __version__ = "1.5.0.post5"

  from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
@@ -172,7 +172,7 @@ def causal_conv1d_ref(
  return out if not return_final_states else (out, final_states_out)


- def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
+ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
  """
  x: (batch, dim) or (batch, dim, seqlen)
  conv_state: (batch, dim, state_len), where state_len >= width - 1
@@ -182,6 +182,10 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cach
  If not None, the conv_state is treated as a circular buffer.
  The conv_state will be updated by copying x to the conv_state starting at the index
  @cache_seqlens % state_len.
+ conv_state_indices: (batch,), dtype int32
+ If not None, the conv_state is a larger tensor along the batch dim,
+ and we are selecting the batch coords specified by conv_state_indices.
+ Useful for a continuous batching scenario.

  out: (batch, dim) or (batch, dim, seqlen)
  """
@@ -192,7 +196,7 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cach
  if unsqueeze:
  x = x.unsqueeze(-1)
  out = causal_conv1d_cuda.causal_conv1d_update(
- x, conv_state, weight, bias, activation, cache_seqlens
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
  )
  if unsqueeze:
  out = out.squeeze(-1)
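The new `conv_state_indices` argument lets a decode step read and update per-request slots inside a state pool that is larger than the current batch, which is what a continuous-batching server needs. The following is a hedged sketch, not taken from the package: the pool size, shapes, and slot assignment are invented for illustration, and a CUDA build is assumed.

```python
# Hedged sketch of continuous batching with conv_state_indices (illustrative shapes).
import torch
from causal_conv1d import causal_conv1d_update

pool_size, batch, dim, width, state_len = 8, 3, 64, 4, 3  # state_len >= width - 1
device, dtype = "cuda", torch.float16

# Pool of cached conv states; only some slots belong to currently active requests.
conv_state = torch.zeros(pool_size, dim, state_len, device=device, dtype=dtype)
weight = torch.randn(dim, width, device=device, dtype=dtype)
bias = torch.randn(dim, device=device, dtype=dtype)

# Three active requests mapped to slots 5, 0 and 2 of the state pool.
conv_state_indices = torch.tensor([5, 0, 2], device=device, dtype=torch.int32)

x = torch.randn(batch, dim, device=device, dtype=dtype)  # one new token per request
out = causal_conv1d_update(
    x, conv_state, weight, bias, activation="silu",
    conv_state_indices=conv_state_indices,
)
# conv_state rows 5, 0 and 2 are updated in place; out has shape (batch, dim).
```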
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: causal-conv1d
- Version: 1.4.0
+ Version: 1.5.0.post5
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
  Home-page: https://github.com/Dao-AILab/causal-conv1d
  Author: Tri Dao
@@ -8,7 +8,7 @@ Author-email: tri@tridao.me
  Classifier: Programming Language :: Python :: 3
  Classifier: License :: OSI Approved :: BSD License
  Classifier: Operating System :: Unix
- Requires-Python: >=3.8
+ Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
  License-File: AUTHORS
@@ -21,11 +21,11 @@ Features:

  ## How to use

- ```
+ ```python
  from causal_conv1d import causal_conv1d_fn
  ```

- ```
+ ```python
  def causal_conv1d_fn(x, weight, bias=None, activation=None):
  """
  x: (batch, dim, seqlen)
@@ -38,7 +38,7 @@ def causal_conv1d_fn(x, weight, bias=None, activation=None):
  ```

  Equivalent to:
- ```
+ ```python
  import torch.nn.functional as F

  F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen]
@@ -267,10 +267,10 @@ def get_wheel_url():
  # We're using the CUDA version used to build torch, not the one currently installed
  # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
  torch_cuda_version = parse(torch.version.cuda)
- # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
+ # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.4
  # to save CI time. Minor versions should be compatible.
- torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
- cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
+ torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.4")
+ cuda_version = f"{torch_cuda_version.major}"

  gpu_compute_version = hip_version if HIP_BUILD else cuda_version
  cuda_or_hip = "hip" if HIP_BUILD else "cu"
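The visible effect of this change is on the CUDA component of the prebuilt-wheel name: the minor version is dropped, so wheels are tagged per CUDA major version, and the CUDA 12 build now targets 12.4. A hedged illustration of just that component, not the package's own code and with hypothetical helper names:

```python
# Illustrative-only helpers mirroring the old and new formatting of the CUDA
# component used in get_wheel_url(); names are hypothetical.
from packaging.version import parse

def cuda_component_old(torch_cuda: str) -> str:
    v = parse("11.8") if parse(torch_cuda).major == 11 else parse("12.2")
    return f"cu{v.major}{v.minor}"   # e.g. "cu118" or "cu122"

def cuda_component_new(torch_cuda: str) -> str:
    v = parse("11.8") if parse(torch_cuda).major == 11 else parse("12.4")
    return f"cu{v.major}"            # e.g. "cu11" or "cu12"

print(cuda_component_old("12.1"), cuda_component_new("12.1"))  # cu122 cu12
```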
@@ -358,7 +358,7 @@ setup(
  else {
  "bdist_wheel": CachedWheelsCommand,
  },
- python_requires=">=3.8",
+ python_requires=">=3.9",
  install_requires=[
  "torch",
  "packaging",