causal-conv1d 1.5.0.post5__tar.gz → 1.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. causal_conv1d-1.5.1/MANIFEST.in +4 -0
  2. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/PKG-INFO +1 -1
  3. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/causal_conv1d/__init__.py +1 -1
  4. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/causal_conv1d/causal_conv1d_interface.py +4 -5
  5. causal_conv1d-1.5.1/causal_conv1d/cpp_functions.py +183 -0
  6. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/causal_conv1d.egg-info/PKG-INFO +1 -1
  7. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/causal_conv1d.egg-info/SOURCES.txt +10 -1
  8. causal_conv1d-1.5.1/csrc/causal_conv1d.h +81 -0
  9. causal_conv1d-1.5.1/csrc/causal_conv1d_bwd.cu +627 -0
  10. causal_conv1d-1.5.1/csrc/causal_conv1d_common.h +98 -0
  11. causal_conv1d-1.5.1/csrc/causal_conv1d_fwd.cu +399 -0
  12. causal_conv1d-1.5.1/csrc/causal_conv1d_update.cu +137 -0
  13. causal_conv1d-1.5.1/csrc/static_switch.h +25 -0
  14. causal_conv1d-1.5.1/pyproject.toml +3 -0
  15. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/setup.py +23 -17
  16. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/AUTHORS +0 -0
  17. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/LICENSE +0 -0
  18. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/README.md +0 -0
  19. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/causal_conv1d/causal_conv1d_varlen.py +0 -0
  20. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/causal_conv1d.egg-info/dependency_links.txt +0 -0
  21. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/causal_conv1d.egg-info/requires.txt +0 -0
  22. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/causal_conv1d.egg-info/top_level.txt +0 -0
  23. {causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/setup.cfg +0 -0
causal_conv1d-1.5.1/MANIFEST.in
@@ -0,0 +1,4 @@
+ include csrc/*.h
+ include csrc/*.cu
+ recursive-include third_party *
+ README.md
{causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: causal_conv1d
- Version: 1.5.0.post5
+ Version: 1.5.1
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
  Home-page: https://github.com/Dao-AILab/causal-conv1d
  Author: Tri Dao
{causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/causal_conv1d/__init__.py
@@ -1,3 +1,3 @@
- __version__ = "1.5.0.post5"
+ __version__ = "1.5.1"
 
  from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
{causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/causal_conv1d/causal_conv1d_interface.py
@@ -3,8 +3,7 @@
  import torch
  import torch.nn.functional as F
 
-
- import causal_conv1d_cuda
+ from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
 
 
  class CausalConv1dFn(torch.autograd.Function):
@@ -54,7 +53,7 @@ class CausalConv1dFn(torch.autograd.Function):
          else:
              final_states_out = None
          ctx.activation = activation in ["silu", "swish"]
-         out = causal_conv1d_cuda.causal_conv1d_fwd(
+         out = causal_conv1d_fwd_function(
              x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
          )
          ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
@@ -73,7 +72,7 @@ class CausalConv1dFn(torch.autograd.Function):
          # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
          # backward of conv1d with the backward of chunk).
          # Here we just pass in None and dx will be allocated in the C++ code.
-         dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
+         dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
              x,
              weight,
              bias,
@@ -195,7 +194,7 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cach
      unsqueeze = x.dim() == 2
      if unsqueeze:
          x = x.unsqueeze(-1)
-     out = causal_conv1d_cuda.causal_conv1d_update(
+     out = causal_conv1d_update_function(
          x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
      )
      if unsqueeze:
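
For orientation, the wrappers touched above are the package's public entry points; a minimal usage sketch (not part of the diff, with shapes and keyword names taken from the project README) looks like this:

# Illustrative only: exercising causal_conv1d_fn / causal_conv1d_update, whose
# internals now route through causal_conv1d.cpp_functions.
import torch
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)

# Full-sequence forward: depthwise causal conv1d with optional fused SiLU.
out = causal_conv1d_fn(x, weight, bias, activation="silu")

# Single-token decoding step against a rolling conv_state buffer.
conv_state = torch.zeros(batch, dim, width, device="cuda", dtype=torch.float16)
x_step = torch.randn(batch, dim, device="cuda", dtype=torch.float16)
out_step = causal_conv1d_update(x_step, conv_state, weight, bias, activation="silu")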
causal_conv1d-1.5.1/causal_conv1d/cpp_functions.py
@@ -0,0 +1,183 @@
+ # Copyright (c) 2024, Tri Dao.
+
+ import torch
+
+ import causal_conv1d_cuda
+
+
+ LIBRARY_NAME = "DaoAILab"
+
+
+ @torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_fwd_cpp", mutates_args={"out", "final_states_out"})
+ def _causal_conv1d_fwd_cpp(
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor | None,
+     seq_idx: torch.Tensor | None,
+     initial_states: torch.Tensor | None,
+     out: torch.Tensor,
+     final_states_out: torch.Tensor | None,
+     silu_activation: bool,
+ ) -> None:
+     causal_conv1d_cuda.causal_conv1d_fwd(
+         x,
+         weight,
+         bias,
+         seq_idx,
+         initial_states,
+         out,
+         final_states_out,
+         silu_activation,
+     )
+
+
+ @torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_bwd_cpp", mutates_args={
+     "dfinal_states",
+     "dx",
+     "dweight",
+     "dbias",
+     "dinitial_states",
+ })
+ def _causal_conv1d_bwd_cpp(
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor | None,
+     dout: torch.Tensor,
+     seq_idx: torch.Tensor | None,
+     initial_states: torch.Tensor | None,
+     dfinal_states: torch.Tensor | None,
+     dx: torch.Tensor,
+     dweight: torch.Tensor,
+     dbias: torch.Tensor | None,
+     dinitial_states: torch.Tensor,
+     silu_activation: bool,
+ ) -> None:
+     causal_conv1d_cuda.causal_conv1d_bwd(
+         x,
+         weight,
+         bias,
+         dout,
+         seq_idx,
+         initial_states,
+         dfinal_states,
+         dx,
+         dweight,
+         dbias,
+         dinitial_states,
+         silu_activation,
+     )
+
+
+ @torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_update_cpp", mutates_args={"out", "conv_state"})
+ def _causal_conv1d_update_cpp(
+     x: torch.Tensor,
+     conv_state: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor | None,
+     out: torch.Tensor,
+     silu_activation: bool,
+     cache_seqlens: torch.Tensor | None,
+     conv_state_indices: torch.Tensor | None,
+ ) -> None:
+     causal_conv1d_cuda.causal_conv1d_update(
+         x,
+         conv_state,
+         weight,
+         bias,
+         out,
+         silu_activation,
+         cache_seqlens,
+         conv_state_indices
+     )
+
+
+ def causal_conv1d_fwd_function(
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor | None,
+     seq_idx: torch.Tensor | None,
+     initial_states: torch.Tensor | None,
+     final_states_out: torch.Tensor | None,
+     silu_activation: bool,
+ ) -> torch.Tensor:
+     out = torch.empty_like(x)
+     _causal_conv1d_fwd_cpp(
+         x=x,
+         weight=weight,
+         bias=bias,
+         seq_idx=seq_idx,
+         initial_states=initial_states,
+         out=out,
+         final_states_out=final_states_out,
+         silu_activation=silu_activation,
+     )
+     return out
+
+
+ def causal_conv1d_bwd_function(
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor | None,
+     dout: torch.Tensor,
+     seq_idx: torch.Tensor | None,
+     initial_states: torch.Tensor | None,
+     dfinal_states: torch.Tensor | None,
+     dx: torch.Tensor | None,
+     return_dinitial_states: torch.Tensor,
+     silu_activation: bool,
+ ) -> tuple[torch.Tensor | None]:
+     batch_size, dim = x.size()[:2]
+     width = weight.size(-1)
+
+     if dx is None:
+         dx = torch.empty_like(x)
+     dweight = torch.zeros_like(weight, dtype=torch.float32)
+     dbias = None
+     if bias is not None:
+         dbias = torch.zeros_like(bias, dtype=torch.float32)
+     dinitial_states = None
+     if return_dinitial_states:
+         dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+     _causal_conv1d_bwd_cpp(
+         x=x,
+         weight=weight,
+         bias=bias,
+         dout=dout,
+         seq_idx=seq_idx,
+         initial_states=initial_states,
+         dfinal_states=dfinal_states,
+         dx=dx,
+         dweight=dweight,
+         dbias=dbias,
+         dinitial_states=dinitial_states,
+         silu_activation=silu_activation,
+     )
+
+     dweight = dweight.type_as(weight)
+     if dbias is not None:
+         dbias = dbias.type_as(bias)
+     return dx, dweight, dbias, dinitial_states
+
+
+ def causal_conv1d_update_function(
+     x: torch.Tensor,
+     conv_state: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor | None,
+     silu_activation: bool,
+     cache_seqlens: torch.Tensor | None,
+     conv_state_indices: torch.Tensor | None,
+ ) -> torch.Tensor:
+     out = torch.empty_like(x)
+     _causal_conv1d_update_cpp(
+         x=x,
+         conv_state=conv_state,
+         weight=weight,
+         bias=bias,
+         out=out,
+         silu_activation=silu_activation,
+         cache_seqlens=cache_seqlens,
+         conv_state_indices=conv_state_indices,
+     )
+     return out
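
The new file above follows PyTorch's torch.library.custom_op pattern (available in PyTorch ≥ 2.4): the raw causal_conv1d_cuda call writes into pre-allocated tensors, that mutation is declared via mutates_args, and a thin Python wrapper allocates the output and returns it, so the opaque extension call can be handled safely by the dispatcher and torch.compile. A self-contained toy sketch of the same pattern (toy_lib and scaled_copy are illustrative names, not part of this package):

# Toy illustration of the custom-op-plus-wrapper pattern used in cpp_functions.py.
import torch

@torch.library.custom_op("toy_lib::scaled_copy", mutates_args={"out"})
def scaled_copy(x: torch.Tensor, scale: float, out: torch.Tensor) -> None:
    # The "kernel": writes into the caller-provided buffer, mirroring how the
    # CUDA extension fills `out` in place.
    out.copy_(x * scale)

def scaled(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Thin wrapper, analogous to causal_conv1d_fwd_function: allocate, call, return.
    out = torch.empty_like(x)
    scaled_copy(x, scale, out)  # also reachable as torch.ops.toy_lib.scaled_copy
    return out

print(scaled(torch.arange(4.0), 2.0))  # tensor([0., 2., 4., 6.])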
{causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/causal_conv1d.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: causal-conv1d
- Version: 1.5.0.post5
+ Version: 1.5.1
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
  Home-page: https://github.com/Dao-AILab/causal-conv1d
  Author: Tri Dao
{causal_conv1d-1.5.0.post5 → causal_conv1d-1.5.1}/causal_conv1d.egg-info/SOURCES.txt
@@ -1,12 +1,21 @@
  AUTHORS
  LICENSE
+ MANIFEST.in
  README.md
+ pyproject.toml
  setup.py
  causal_conv1d/__init__.py
  causal_conv1d/causal_conv1d_interface.py
  causal_conv1d/causal_conv1d_varlen.py
+ causal_conv1d/cpp_functions.py
  causal_conv1d.egg-info/PKG-INFO
  causal_conv1d.egg-info/SOURCES.txt
  causal_conv1d.egg-info/dependency_links.txt
  causal_conv1d.egg-info/requires.txt
- causal_conv1d.egg-info/top_level.txt
+ causal_conv1d.egg-info/top_level.txt
+ csrc/causal_conv1d.h
+ csrc/causal_conv1d_bwd.cu
+ csrc/causal_conv1d_common.h
+ csrc/causal_conv1d_fwd.cu
+ csrc/causal_conv1d_update.cu
+ csrc/static_switch.h
causal_conv1d-1.5.1/csrc/causal_conv1d.h
@@ -0,0 +1,81 @@
+ /******************************************************************************
+  * Copyright (c) 2024, Tri Dao.
+  ******************************************************************************/
+
+ #pragma once
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ struct ConvParamsBase {
+     using index_t = uint32_t;
+
+     int batch, dim, seqlen, width;
+     bool silu_activation;
+
+     index_t x_batch_stride;
+     index_t x_c_stride;
+     index_t x_l_stride;
+     index_t weight_c_stride;
+     index_t weight_width_stride;
+     index_t out_batch_stride;
+     index_t out_c_stride;
+     index_t out_l_stride;
+
+     int conv_state_len;
+     index_t conv_state_batch_stride;
+     index_t conv_state_c_stride;
+     index_t conv_state_l_stride;
+
+     // Common data pointers.
+     void *__restrict__ x_ptr;
+     void *__restrict__ weight_ptr;
+     void *__restrict__ bias_ptr;
+     void *__restrict__ out_ptr;
+
+     void *__restrict__ conv_state_ptr;
+     int32_t *__restrict__ cache_seqlens;
+
+     // Only used if the elements of the batch are gathered from a larger buffer,
+     // which may happen for continuous batching.
+     int32_t *__restrict__ conv_state_indices_ptr;
+
+     void *__restrict__ seq_idx_ptr;
+
+     // No __restrict__ since initial_states could be the same as final_states.
+     void * initial_states_ptr;
+     index_t initial_states_batch_stride;
+     index_t initial_states_l_stride;
+     index_t initial_states_c_stride;
+
+     void * final_states_ptr;
+     index_t final_states_batch_stride;
+     index_t final_states_l_stride;
+     index_t final_states_c_stride;
+ };
+
+ struct ConvParamsBwd: public ConvParamsBase {
+     index_t dx_batch_stride;
+     index_t dx_c_stride;
+     index_t dx_l_stride;
+     index_t dweight_c_stride;
+     index_t dweight_width_stride;
+     index_t dout_batch_stride;
+     index_t dout_c_stride;
+     index_t dout_l_stride;
+
+     // Common data pointers.
+     void *__restrict__ dx_ptr;
+     void *__restrict__ dweight_ptr;
+     void *__restrict__ dbias_ptr;
+     void *__restrict__ dout_ptr;
+
+     void * dinitial_states_ptr;
+     index_t dinitial_states_batch_stride;
+     index_t dinitial_states_l_stride;
+     index_t dinitial_states_c_stride;
+
+     void * dfinal_states_ptr;
+     index_t dfinal_states_batch_stride;
+     index_t dfinal_states_l_stride;
+     index_t dfinal_states_c_stride;
+ };
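
The *_stride members above mirror the element strides of the PyTorch tensors handed to the kernels; the code that fills them lives in the C++ bindings, which are not part of this diff. A rough Python illustration of what the (batch, dim, seqlen) layout implies:

# Illustrative only: how x_batch_stride / x_c_stride / x_l_stride relate to
# torch strides for a contiguous (batch, dim, seqlen) input.
import torch

batch, dim, seqlen = 2, 8, 16
x = torch.randn(batch, dim, seqlen)

x_batch_stride = x.stride(0)  # elements between batch entries -> dim * seqlen = 128
x_c_stride = x.stride(1)      # elements between channels      -> seqlen = 16
x_l_stride = x.stride(2)      # elements between time steps    -> 1
assert (x_batch_stride, x_c_stride, x_l_stride) == (dim * seqlen, seqlen, 1)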