causal-conv1d 1.5.0.post8__tar.gz → 1.5.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24) hide show
  1. causal_conv1d-1.5.2/MANIFEST.in +3 -0
  2. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/PKG-INFO +1 -1
  3. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d/__init__.py +1 -1
  4. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d/causal_conv1d_interface.py +4 -5
  5. causal_conv1d-1.5.2/causal_conv1d/cpp_functions.py +183 -0
  6. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d.egg-info/PKG-INFO +1 -1
  7. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d.egg-info/SOURCES.txt +11 -1
  8. causal_conv1d-1.5.2/csrc/causal_conv1d.cpp +466 -0
  9. causal_conv1d-1.5.2/csrc/causal_conv1d.h +81 -0
  10. causal_conv1d-1.5.2/csrc/causal_conv1d_bwd.cu +627 -0
  11. causal_conv1d-1.5.2/csrc/causal_conv1d_common.h +98 -0
  12. causal_conv1d-1.5.2/csrc/causal_conv1d_fwd.cu +399 -0
  13. causal_conv1d-1.5.2/csrc/causal_conv1d_update.cu +137 -0
  14. causal_conv1d-1.5.2/csrc/static_switch.h +25 -0
  15. causal_conv1d-1.5.2/pyproject.toml +3 -0
  16. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/setup.py +22 -16
  17. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/AUTHORS +0 -0
  18. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/LICENSE +0 -0
  19. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/README.md +0 -0
  20. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d/causal_conv1d_varlen.py +0 -0
  21. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d.egg-info/dependency_links.txt +0 -0
  22. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d.egg-info/requires.txt +0 -0
  23. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d.egg-info/top_level.txt +0 -0
  24. {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/setup.cfg +0 -0
@@ -0,0 +1,3 @@
1
+ recursive-include csrc *
2
+ recursive-include third_party *
3
+ README.md
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: causal_conv1d
3
- Version: 1.5.0.post8
3
+ Version: 1.5.2
4
4
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
5
5
  Home-page: https://github.com/Dao-AILab/causal-conv1d
6
6
  Author: Tri Dao
@@ -1,3 +1,3 @@
1
- __version__ = "1.5.0.post8"
1
+ __version__ = "1.5.2"
2
2
 
3
3
  from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
@@ -3,8 +3,7 @@
3
3
  import torch
4
4
  import torch.nn.functional as F
5
5
 
6
-
7
- import causal_conv1d_cuda
6
+ from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
8
7
 
9
8
 
10
9
  class CausalConv1dFn(torch.autograd.Function):
@@ -54,7 +53,7 @@ class CausalConv1dFn(torch.autograd.Function):
54
53
  else:
55
54
  final_states_out = None
56
55
  ctx.activation = activation in ["silu", "swish"]
57
- out = causal_conv1d_cuda.causal_conv1d_fwd(
56
+ out = causal_conv1d_fwd_function(
58
57
  x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
59
58
  )
60
59
  ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
@@ -73,7 +72,7 @@ class CausalConv1dFn(torch.autograd.Function):
73
72
  # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
74
73
  # backward of conv1d with the backward of chunk).
75
74
  # Here we just pass in None and dx will be allocated in the C++ code.
76
- dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
75
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
77
76
  x,
78
77
  weight,
79
78
  bias,
@@ -195,7 +194,7 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cach
195
194
  unsqueeze = x.dim() == 2
196
195
  if unsqueeze:
197
196
  x = x.unsqueeze(-1)
198
- out = causal_conv1d_cuda.causal_conv1d_update(
197
+ out = causal_conv1d_update_function(
199
198
  x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
200
199
  )
201
200
  if unsqueeze:
@@ -0,0 +1,183 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import torch
4
+
5
+ import causal_conv1d_cuda
6
+
7
+
8
+ LIBRARY_NAME = "DaoAILab"
9
+
10
+
11
@torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_fwd_cpp", mutates_args={"out", "final_states_out"})
def _causal_conv1d_fwd_cpp(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    out: torch.Tensor,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> None:
    """Registered custom op: forward causal conv1d via the CUDA extension.

    Writes the result into ``out`` (and into ``final_states_out`` when it is
    provided) in place; the mutation is declared through ``mutates_args`` so
    the torch.library machinery can account for it.
    """
    # Thin pass-through to the compiled kernel; all buffers are allocated by the caller.
    causal_conv1d_cuda.causal_conv1d_fwd(
        x, weight, bias, seq_idx, initial_states, out, final_states_out, silu_activation
    )
32
+
33
+
34
@torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_bwd_cpp", mutates_args={
    "dfinal_states",
    "dx",
    "dweight",
    "dbias",
    "dinitial_states",
})
def _causal_conv1d_bwd_cpp(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor,
    dweight: torch.Tensor,
    dbias: torch.Tensor | None,
    # BUGFIX: the wrapper passes None here when initial-state gradients are
    # not requested, and torch.library.custom_op derives the op schema from
    # these annotations — so the argument must be declared optional.
    dinitial_states: torch.Tensor | None,
    silu_activation: bool,
) -> None:
    """Registered custom op: backward causal conv1d via the CUDA extension.

    Accumulates gradients into the pre-allocated ``dx``/``dweight``/``dbias``/
    ``dinitial_states`` buffers in place (hence ``mutates_args``).
    """
    causal_conv1d_cuda.causal_conv1d_bwd(
        x,
        weight,
        bias,
        dout,
        seq_idx,
        initial_states,
        dfinal_states,
        dx,
        dweight,
        dbias,
        dinitial_states,
        silu_activation,
    )
69
+
70
+
71
@torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_update_cpp", mutates_args={"out", "conv_state"})
def _causal_conv1d_update_cpp(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    out: torch.Tensor,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> None:
    """Registered custom op: single-step conv-state update via the CUDA extension.

    Mutates ``out`` and ``conv_state`` in place, as declared in ``mutates_args``.
    """
    # Thin pass-through to the compiled kernel; buffers are allocated by the caller.
    causal_conv1d_cuda.causal_conv1d_update(
        x, conv_state, weight, bias, out, silu_activation, cache_seqlens, conv_state_indices
    )
92
+
93
+
94
def causal_conv1d_fwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    final_states_out: torch.Tensor | None,
    silu_activation: bool,
) -> torch.Tensor:
    """Allocate the output buffer and invoke the forward custom op.

    Returns the convolution output; ``final_states_out`` (when provided) is
    filled in place by the kernel.
    """
    result = torch.empty_like(x)
    _causal_conv1d_fwd_cpp(
        x=x, weight=weight, bias=bias, seq_idx=seq_idx,
        initial_states=initial_states, out=result,
        final_states_out=final_states_out, silu_activation=silu_activation,
    )
    return result
115
+
116
+
117
def causal_conv1d_bwd_function(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    dout: torch.Tensor,
    seq_idx: torch.Tensor | None,
    initial_states: torch.Tensor | None,
    dfinal_states: torch.Tensor | None,
    dx: torch.Tensor | None,
    # BUGFIX: this argument is used as a boolean flag, not a tensor.
    return_dinitial_states: bool,
    silu_activation: bool,
    # BUGFIX: tuple[torch.Tensor | None] declares a 1-tuple; four values are
    # returned, so all four slots are spelled out.
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Allocate gradient buffers and invoke the backward custom op.

    Returns ``(dx, dweight, dbias, dinitial_states)``. ``dbias`` is None when
    ``bias`` is None; ``dinitial_states`` is None unless
    ``return_dinitial_states`` is true. A pre-allocated ``dx`` may be passed
    in (e.g. to fuse with another backward); otherwise it is allocated here.
    """
    batch_size, dim = x.size()[:2]
    width = weight.size(-1)

    if dx is None:
        dx = torch.empty_like(x)
    # Accumulate weight/bias gradients in fp32, then cast back to the
    # parameter dtype after the kernel runs.
    dweight = torch.zeros_like(weight, dtype=torch.float32)
    dbias = None
    if bias is not None:
        dbias = torch.zeros_like(bias, dtype=torch.float32)
    dinitial_states = None
    if return_dinitial_states:
        # Allocated as (batch, width-1, dim) then transposed so the kernel
        # receives a (batch, dim, width-1) view with this memory layout.
        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)

    _causal_conv1d_bwd_cpp(
        x=x,
        weight=weight,
        bias=bias,
        dout=dout,
        seq_idx=seq_idx,
        initial_states=initial_states,
        dfinal_states=dfinal_states,
        dx=dx,
        dweight=dweight,
        dbias=dbias,
        dinitial_states=dinitial_states,
        silu_activation=silu_activation,
    )

    dweight = dweight.type_as(weight)
    if dbias is not None:
        dbias = dbias.type_as(bias)
    return dx, dweight, dbias, dinitial_states
161
+
162
+
163
def causal_conv1d_update_function(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    silu_activation: bool,
    cache_seqlens: torch.Tensor | None,
    conv_state_indices: torch.Tensor | None,
) -> torch.Tensor:
    """Allocate the output buffer and invoke the update custom op.

    ``conv_state`` is advanced in place by the kernel; the single-step
    output is returned.
    """
    step_out = torch.empty_like(x)
    _causal_conv1d_update_cpp(
        x=x, conv_state=conv_state, weight=weight, bias=bias,
        out=step_out, silu_activation=silu_activation,
        cache_seqlens=cache_seqlens, conv_state_indices=conv_state_indices,
    )
    return step_out
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: causal-conv1d
3
- Version: 1.5.0.post8
3
+ Version: 1.5.2
4
4
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
5
5
  Home-page: https://github.com/Dao-AILab/causal-conv1d
6
6
  Author: Tri Dao
@@ -1,12 +1,22 @@
1
1
  AUTHORS
2
2
  LICENSE
3
+ MANIFEST.in
3
4
  README.md
5
+ pyproject.toml
4
6
  setup.py
5
7
  causal_conv1d/__init__.py
6
8
  causal_conv1d/causal_conv1d_interface.py
7
9
  causal_conv1d/causal_conv1d_varlen.py
10
+ causal_conv1d/cpp_functions.py
8
11
  causal_conv1d.egg-info/PKG-INFO
9
12
  causal_conv1d.egg-info/SOURCES.txt
10
13
  causal_conv1d.egg-info/dependency_links.txt
11
14
  causal_conv1d.egg-info/requires.txt
12
- causal_conv1d.egg-info/top_level.txt
15
+ causal_conv1d.egg-info/top_level.txt
16
+ csrc/causal_conv1d.cpp
17
+ csrc/causal_conv1d.h
18
+ csrc/causal_conv1d_bwd.cu
19
+ csrc/causal_conv1d_common.h
20
+ csrc/causal_conv1d_fwd.cu
21
+ csrc/causal_conv1d_update.cu
22
+ csrc/static_switch.h