causal-conv1d 1.5.0.post8.tar.gz → 1.5.2.tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- causal_conv1d-1.5.2/MANIFEST.in +3 -0
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/PKG-INFO +1 -1
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d/__init__.py +1 -1
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d/causal_conv1d_interface.py +4 -5
- causal_conv1d-1.5.2/causal_conv1d/cpp_functions.py +183 -0
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d.egg-info/PKG-INFO +1 -1
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d.egg-info/SOURCES.txt +11 -1
- causal_conv1d-1.5.2/csrc/causal_conv1d.cpp +466 -0
- causal_conv1d-1.5.2/csrc/causal_conv1d.h +81 -0
- causal_conv1d-1.5.2/csrc/causal_conv1d_bwd.cu +627 -0
- causal_conv1d-1.5.2/csrc/causal_conv1d_common.h +98 -0
- causal_conv1d-1.5.2/csrc/causal_conv1d_fwd.cu +399 -0
- causal_conv1d-1.5.2/csrc/causal_conv1d_update.cu +137 -0
- causal_conv1d-1.5.2/csrc/static_switch.h +25 -0
- causal_conv1d-1.5.2/pyproject.toml +3 -0
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/setup.py +22 -16
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/AUTHORS +0 -0
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/LICENSE +0 -0
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/README.md +0 -0
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d/causal_conv1d_varlen.py +0 -0
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d.egg-info/dependency_links.txt +0 -0
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d.egg-info/requires.txt +0 -0
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d.egg-info/top_level.txt +0 -0
- {causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/setup.cfg +0 -0
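Note: setup.py changes by +22/−16 lines in this release, but its contents are not part of this diff. Since the csrc/ kernels now ship inside the sdist (see the SOURCES.txt hunk below), the build presumably goes through the usual torch CUDAExtension route. The following is a generic sketch of that pattern for orientation only, not the package's actual setup.py; every value in it besides the csrc/ file names is an assumption:

# Generic CUDAExtension build sketch; NOT causal-conv1d's real setup.py.
# The csrc/ file names come from the file list above; everything else
# (version, compile flags, packaging metadata) is omitted or assumed.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="causal_conv1d",
    ext_modules=[
        CUDAExtension(
            name="causal_conv1d_cuda",
            sources=[
                "csrc/causal_conv1d.cpp",
                "csrc/causal_conv1d_fwd.cu",
                "csrc/causal_conv1d_bwd.cu",
                "csrc/causal_conv1d_update.cu",
            ],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)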
{causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d/causal_conv1d_interface.py

@@ -3,8 +3,7 @@
 import torch
 import torch.nn.functional as F
 
-
-import causal_conv1d_cuda
+from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
 
 
 class CausalConv1dFn(torch.autograd.Function):

@@ -54,7 +53,7 @@ class CausalConv1dFn(torch.autograd.Function):
         else:
             final_states_out = None
         ctx.activation = activation in ["silu", "swish"]
-        out = causal_conv1d_cuda.causal_conv1d_fwd(
+        out = causal_conv1d_fwd_function(
             x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
         )
         ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)

@@ -73,7 +72,7 @@ class CausalConv1dFn(torch.autograd.Function):
         # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
         # backward of conv1d with the backward of chunk).
         # Here we just pass in None and dx will be allocated in the C++ code.
-        dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
+        dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
             x,
             weight,
             bias,

@@ -195,7 +194,7 @@ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cach
     unsqueeze = x.dim() == 2
     if unsqueeze:
         x = x.unsqueeze(-1)
-    out = causal_conv1d_cuda.causal_conv1d_update(
+    out = causal_conv1d_update_function(
        x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
     )
     if unsqueeze:
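The four hunks above are the only changes to causal_conv1d_interface.py: every direct call into the causal_conv1d_cuda extension is rerouted through the wrappers in the new causal_conv1d/cpp_functions.py (next hunk), and the public API is unchanged. A minimal usage sketch follows; causal_conv1d_fn's full signature is not shown in this diff, so the keyword argument below is an assumption:

# Hypothetical usage sketch of the public API, unchanged by this refactor.
import torch
from causal_conv1d import causal_conv1d_fn

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)

# Depthwise causal conv over the sequence dimension with fused SiLU.
out = causal_conv1d_fn(x, weight, bias, activation="silu")
assert out.shape == x.shape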
causal_conv1d-1.5.2/causal_conv1d/cpp_functions.py (new file)

@@ -0,0 +1,183 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+
+import causal_conv1d_cuda
+
+
+LIBRARY_NAME = "DaoAILab"
+
+
+@torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_fwd_cpp", mutates_args={"out", "final_states_out"})
+def _causal_conv1d_fwd_cpp(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    seq_idx: torch.Tensor | None,
+    initial_states: torch.Tensor | None,
+    out: torch.Tensor,
+    final_states_out: torch.Tensor | None,
+    silu_activation: bool,
+) -> None:
+    causal_conv1d_cuda.causal_conv1d_fwd(
+        x,
+        weight,
+        bias,
+        seq_idx,
+        initial_states,
+        out,
+        final_states_out,
+        silu_activation,
+    )
+
+
+@torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_bwd_cpp", mutates_args={
+    "dfinal_states",
+    "dx",
+    "dweight",
+    "dbias",
+    "dinitial_states",
+})
+def _causal_conv1d_bwd_cpp(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    dout: torch.Tensor,
+    seq_idx: torch.Tensor | None,
+    initial_states: torch.Tensor | None,
+    dfinal_states: torch.Tensor | None,
+    dx: torch.Tensor,
+    dweight: torch.Tensor,
+    dbias: torch.Tensor | None,
+    dinitial_states: torch.Tensor,
+    silu_activation: bool,
+) -> None:
+    causal_conv1d_cuda.causal_conv1d_bwd(
+        x,
+        weight,
+        bias,
+        dout,
+        seq_idx,
+        initial_states,
+        dfinal_states,
+        dx,
+        dweight,
+        dbias,
+        dinitial_states,
+        silu_activation,
+    )
+
+
+@torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_update_cpp", mutates_args={"out", "conv_state"})
+def _causal_conv1d_update_cpp(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    out: torch.Tensor,
+    silu_activation: bool,
+    cache_seqlens: torch.Tensor | None,
+    conv_state_indices: torch.Tensor | None,
+) -> None:
+    causal_conv1d_cuda.causal_conv1d_update(
+        x,
+        conv_state,
+        weight,
+        bias,
+        out,
+        silu_activation,
+        cache_seqlens,
+        conv_state_indices
+    )
+
+
+def causal_conv1d_fwd_function(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    seq_idx: torch.Tensor | None,
+    initial_states: torch.Tensor | None,
+    final_states_out: torch.Tensor | None,
+    silu_activation: bool,
+) -> torch.Tensor:
+    out = torch.empty_like(x)
+    _causal_conv1d_fwd_cpp(
+        x=x,
+        weight=weight,
+        bias=bias,
+        seq_idx=seq_idx,
+        initial_states=initial_states,
+        out=out,
+        final_states_out=final_states_out,
+        silu_activation=silu_activation,
+    )
+    return out
+
+
+def causal_conv1d_bwd_function(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    dout: torch.Tensor,
+    seq_idx: torch.Tensor | None,
+    initial_states: torch.Tensor | None,
+    dfinal_states: torch.Tensor | None,
+    dx: torch.Tensor | None,
+    return_dinitial_states: torch.Tensor,
+    silu_activation: bool,
+) -> tuple[torch.Tensor | None]:
+    batch_size, dim = x.size()[:2]
+    width = weight.size(-1)
+
+    if dx is None:
+        dx = torch.empty_like(x)
+    dweight = torch.zeros_like(weight, dtype=torch.float32)
+    dbias = None
+    if bias is not None:
+        dbias = torch.zeros_like(bias, dtype=torch.float32)
+    dinitial_states = None
+    if return_dinitial_states:
+        dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
+
+    _causal_conv1d_bwd_cpp(
+        x=x,
+        weight=weight,
+        bias=bias,
+        dout=dout,
+        seq_idx=seq_idx,
+        initial_states=initial_states,
+        dfinal_states=dfinal_states,
+        dx=dx,
+        dweight=dweight,
+        dbias=dbias,
+        dinitial_states=dinitial_states,
+        silu_activation=silu_activation,
+    )
+
+    dweight = dweight.type_as(weight)
+    if dbias is not None:
+        dbias = dbias.type_as(bias)
+    return dx, dweight, dbias, dinitial_states
+
+
+def causal_conv1d_update_function(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor | None,
+    silu_activation: bool,
+    cache_seqlens: torch.Tensor | None,
+    conv_state_indices: torch.Tensor | None,
+) -> torch.Tensor:
+    out = torch.empty_like(x)
+    _causal_conv1d_update_cpp(
+        x=x,
+        conv_state=conv_state,
+        weight=weight,
+        bias=bias,
+        out=out,
+        silu_activation=silu_activation,
+        cache_seqlens=cache_seqlens,
+        conv_state_indices=conv_state_indices,
+    )
+    return out
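The new file follows the torch.library.custom_op pattern throughout: each op is registered with an explicit mutates_args set and returns None, while the plain-Python wrapper allocates the output with torch.empty_like and lets the op fill it in place. Declaring the mutation set gives PyTorch an accurate schema for the opaque extension calls, which is what keeps them usable under torch.compile; the diff itself does not state the motivation, so that reading is an inference. A self-contained toy version of the same pattern (toylib, scale_into, and scale are illustrative names, not part of this package):

import torch

# Register an op that writes into a caller-allocated tensor; declaring
# mutates_args={"out"} tells PyTorch which argument is mutated in place.
@torch.library.custom_op("toylib::scale_into", mutates_args={"out"})
def scale_into(x: torch.Tensor, factor: float, out: torch.Tensor) -> None:
    out.copy_(x * factor)

def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    out = torch.empty_like(x)   # allocated in Python, like `out`/`dx` above
    scale_into(x, factor, out)  # the custom op fills it in place
    return out

print(scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])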
{causal_conv1d-1.5.0.post8 → causal_conv1d-1.5.2}/causal_conv1d.egg-info/SOURCES.txt

@@ -1,12 +1,22 @@
 AUTHORS
 LICENSE
+MANIFEST.in
 README.md
+pyproject.toml
 setup.py
 causal_conv1d/__init__.py
 causal_conv1d/causal_conv1d_interface.py
 causal_conv1d/causal_conv1d_varlen.py
+causal_conv1d/cpp_functions.py
 causal_conv1d.egg-info/PKG-INFO
 causal_conv1d.egg-info/SOURCES.txt
 causal_conv1d.egg-info/dependency_links.txt
 causal_conv1d.egg-info/requires.txt
-causal_conv1d.egg-info/top_level.txt
+causal_conv1d.egg-info/top_level.txt
+csrc/causal_conv1d.cpp
+csrc/causal_conv1d.h
+csrc/causal_conv1d_bwd.cu
+csrc/causal_conv1d_common.h
+csrc/causal_conv1d_fwd.cu
+csrc/causal_conv1d_update.cu
+csrc/static_switch.h
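The three-line MANIFEST.in added in 1.5.2 is listed in the header but its contents are not shown in this diff. Whatever it says, its evident job is to pull the csrc/ sources above into the sdist. A plausible reconstruction, for illustration only; the actual directives may differ:

recursive-include csrc *.cu
recursive-include csrc *.h
recursive-include csrc *.cpp

With directives along these lines, setuptools includes the CUDA/C++ kernel sources when building the source distribution, which is consistent with the new csrc/ entries in SOURCES.txt.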