causal-conv1d 1.1.3.post1.tar.gz → 1.2.0.post1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {causal_conv1d-1.1.3.post1 → causal_conv1d-1.2.0.post1}/PKG-INFO +1 -1
- {causal_conv1d-1.1.3.post1 → causal_conv1d-1.2.0.post1}/causal_conv1d/__init__.py +1 -1
- causal_conv1d-1.2.0.post1/causal_conv1d/causal_conv1d_interface.py +213 -0
- {causal_conv1d-1.1.3.post1 → causal_conv1d-1.2.0.post1}/causal_conv1d.egg-info/PKG-INFO +1 -1
- {causal_conv1d-1.1.3.post1 → causal_conv1d-1.2.0.post1}/causal_conv1d.egg-info/requires.txt +0 -1
- {causal_conv1d-1.1.3.post1 → causal_conv1d-1.2.0.post1}/setup.py +0 -1
- causal_conv1d-1.1.3.post1/causal_conv1d/causal_conv1d_interface.py +0 -107
- {causal_conv1d-1.1.3.post1 → causal_conv1d-1.2.0.post1}/AUTHORS +0 -0
- {causal_conv1d-1.1.3.post1 → causal_conv1d-1.2.0.post1}/LICENSE +0 -0
- {causal_conv1d-1.1.3.post1 → causal_conv1d-1.2.0.post1}/README.md +0 -0
- {causal_conv1d-1.1.3.post1 → causal_conv1d-1.2.0.post1}/causal_conv1d.egg-info/SOURCES.txt +0 -0
- {causal_conv1d-1.1.3.post1 → causal_conv1d-1.2.0.post1}/causal_conv1d.egg-info/dependency_links.txt +0 -0
- {causal_conv1d-1.1.3.post1 → causal_conv1d-1.2.0.post1}/causal_conv1d.egg-info/top_level.txt +0 -0
- {causal_conv1d-1.1.3.post1 → causal_conv1d-1.2.0.post1}/setup.cfg +0 -0
causal_conv1d-1.2.0.post1/causal_conv1d/causal_conv1d_interface.py (new file)
@@ -0,0 +1,213 @@
+# Copyright (c) 2024, Tri Dao.
+
+import torch
+import torch.nn.functional as F
+
+
+import causal_conv1d_cuda
+
+
+class CausalConv1dFn(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias=None,
+        seq_idx=None,
+        initial_states=None,
+        return_final_states=False,
+        final_states_out=None,
+        activation=None,
+    ):
+        if activation not in [None, "silu", "swish"]:
+            raise NotImplementedError("activation must be None, silu, or swish")
+        if x.stride(2) != 1 and x.stride(1) != 1:
+            x = x.contiguous()
+        bias = bias.contiguous() if bias is not None else None
+        if seq_idx is not None:
+            assert (
+                initial_states is None
+            ), "initial_states must be None if seq_idx is not None"
+            assert (
+                not return_final_states
+            ), "If seq_idx is not None, we don't return final_states_out"
+        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+        if initial_states is not None and (
+            initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+        ):
+            initial_states = initial_states.contiguous()
+        if return_final_states:
+            assert (
+                x.stride(1) == 1
+            ), "Only channel-last layout support returning final_states_out"
+            if final_states_out is not None:
+                assert (
+                    final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+                )
+            else:
+                batch, dim, seqlen = x.shape
+                width = weight.shape[1]
+                final_states_out = torch.empty(
+                    batch, width - 1, dim, device=x.device, dtype=x.dtype
+                ).transpose(1, 2)
+        else:
+            final_states_out = None
+        ctx.activation = activation in ["silu", "swish"]
+        out = causal_conv1d_cuda.causal_conv1d_fwd(
+            x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+        )
+        ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+        ctx.return_final_states = return_final_states
+        ctx.return_dinitial_states = (
+            initial_states is not None and initial_states.requires_grad
+        )
+        return out if not return_final_states else (out, final_states_out)
+
+    @staticmethod
+    def backward(ctx, dout, *args):
+        x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+        dfinal_states = args[0] if ctx.return_final_states else None
+        if dout.stride(2) != 1 and dout.stride(1) != 1:
+            dout = dout.contiguous()
+        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+        # backward of conv1d with the backward of chunk).
+        # Here we just pass in None and dx will be allocated in the C++ code.
+        dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
+            x,
+            weight,
+            bias,
+            dout,
+            seq_idx,
+            initial_states,
+            dfinal_states,
+            None,
+            ctx.return_dinitial_states,
+            ctx.activation,
+        )
+        return (
+            dx,
+            dweight,
+            dbias if bias is not None else None,
+            None,
+            dinitial_states if initial_states is not None else None,
+            None,
+            None,
+            None,
+        )
+
+
+def causal_conv1d_fn(
+    x,
+    weight,
+    bias=None,
+    seq_idx=None,
+    initial_states=None,
+    return_final_states=False,
+    final_states_out=None,
+    activation=None,
+):
+    """
+    x: (batch, dim, seqlen)
+    weight: (dim, width)
+    bias: (dim,)
+    seq_idx: (batch, seqlen)
+    initial_states: (batch, dim, width - 1)
+    final_states_out: (batch, dim, width - 1), to be written to
+    activation: either None or "silu" or "swish"
+
+    out: (batch, dim, seqlen)
+    """
+    return CausalConv1dFn.apply(
+        x,
+        weight,
+        bias,
+        seq_idx,
+        initial_states,
+        return_final_states,
+        final_states_out,
+        activation,
+    )
+
+
+def causal_conv1d_ref(
+    x,
+    weight,
+    bias=None,
+    initial_states=None,
+    return_final_states=False,
+    final_states_out=None,
+    activation=None,
+):
+    """
+    x: (batch, dim, seqlen)
+    weight: (dim, width)
+    bias: (dim,)
+    initial_states: (batch, dim, width - 1)
+    final_states_out: (batch, dim, width - 1)
+
+    out: (batch, dim, seqlen)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    dtype_in = x.dtype
+    x = x.to(weight.dtype)
+    seqlen = x.shape[-1]
+    dim, width = weight.shape
+    if initial_states is None:
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+    else:
+        x = torch.cat([initial_states, x], dim=-1)
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+    out = out[..., :seqlen]
+    if return_final_states:
+        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+            dtype_in
+        )  # (batch, dim, width - 1)
+        if final_states_out is not None:
+            final_states_out.copy_(final_states)
+        else:
+            final_states_out = final_states
+    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+    return out if not return_final_states else (out, final_states_out)
+
+
+def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
+    """
+    x: (batch, dim)
+    conv_state: (batch, dim, width)
+    weight: (dim, width)
+    bias: (dim,)
+
+    out: (batch, dim)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    activation = activation in ["silu", "swish"]
+    return causal_conv1d_cuda.causal_conv1d_update(
+        x, conv_state, weight, bias, activation
+    )
+
+
+def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
+    """
+    x: (batch, dim)
+    conv_state: (batch, dim, width)
+    weight: (dim, width)
+    bias: (dim,)
+
+    out: (batch, dim)
+    """
+    if activation not in [None, "silu", "swish"]:
+        raise NotImplementedError("activation must be None, silu, or swish")
+    dtype_in = x.dtype
+    batch, dim = x.shape
+    width = weight.shape[1]
+    assert conv_state.shape == (batch, dim, width)
+    assert weight.shape == (dim, width)
+    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
+    conv_state[:, :, -1] = x
+    out = torch.sum(conv_state * weight, dim=-1)  # (B D)
+    if bias is not None:
+        out += bias
+    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
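The main functional change in 1.2.0 is the initial_states / return_final_states / final_states_out plumbing, which lets a long sequence be processed in chunks while carrying the convolution state across chunk boundaries. Below is a minimal sketch of that pattern using the pure-PyTorch causal_conv1d_ref shown above; the shapes, seed, and tolerance are illustrative only, and it assumes the package is installed (the interface module imports the compiled causal_conv1d_cuda extension at import time, even though the _ref functions themselves use plain PyTorch ops).

import torch

from causal_conv1d.causal_conv1d_interface import causal_conv1d_ref

torch.manual_seed(0)
batch, dim, seqlen, width = 2, 8, 32, 4
x = torch.randn(batch, dim, seqlen)
weight = torch.randn(dim, width)
bias = torch.randn(dim)

# One pass over the full sequence.
out_full = causal_conv1d_ref(x, weight, bias, activation="silu")

# Same sequence in two chunks: the final states of the first chunk
# (its last width - 1 timesteps) seed the second chunk.
x1, x2 = x[..., : seqlen // 2], x[..., seqlen // 2 :]
out1, states = causal_conv1d_ref(
    x1, weight, bias, return_final_states=True, activation="silu"
)
out2 = causal_conv1d_ref(x2, weight, bias, initial_states=states, activation="silu")

# Chunked and full-sequence outputs agree.
assert torch.allclose(torch.cat([out1, out2], dim=-1), out_full, atol=1e-5)

The CUDA-backed causal_conv1d_fn accepts the same keywords, with the additional constraint asserted in CausalConv1dFn.forward that return_final_states=True requires a channel-last layout, i.e. x.stride(1) == 1.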
causal_conv1d-1.1.3.post1/causal_conv1d/causal_conv1d_interface.py (removed)
@@ -1,107 +0,0 @@
-# Copyright (c) 2023, Tri Dao.
-
-import torch
-import torch.nn.functional as F
-
-
-import causal_conv1d_cuda
-
-
-class CausalConv1dFn(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, x, weight, bias=None, seq_idx=None, activation=None):
-        if activation not in [None, "silu", "swish"]:
-            raise NotImplementedError("activation must be None, silu, or swish")
-        if x.stride(2) != 1 and x.stride(1) != 1:
-            x = x.contiguous()
-        bias = bias.contiguous() if bias is not None else None
-        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
-        ctx.save_for_backward(x, weight, bias, seq_idx)
-        ctx.activation = activation in ["silu", "swish"]
-        out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, seq_idx, ctx.activation)
-        return out
-
-    @staticmethod
-    def backward(ctx, dout):
-        x, weight, bias, seq_idx = ctx.saved_tensors
-        if dout.stride(2) != 1 and dout.stride(1) != 1:
-            dout = dout.contiguous()
-        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
-        # backward of conv1d with the backward of chunk).
-        # Here we just pass in None and dx will be allocated in the C++ code.
-        dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd(
-            x, weight, bias, dout, seq_idx, None, ctx.activation
-        )
-        return dx, dweight, dbias if bias is not None else None, None, None
-
-
-def causal_conv1d_fn(x, weight, bias=None, seq_idx=None, activation=None):
-    """
-    x: (batch, dim, seqlen)
-    weight: (dim, width)
-    bias: (dim,)
-    seq_idx: (batch, seqlen)
-    activation: either None or "silu" or "swish"
-
-    out: (batch, dim, seqlen)
-    """
-    return CausalConv1dFn.apply(x, weight, bias, seq_idx, activation)
-
-
-def causal_conv1d_ref(x, weight, bias=None, activation=None):
-    """
-    x: (batch, dim, seqlen)
-    weight: (dim, width)
-    bias: (dim,)
-
-    out: (batch, dim, seqlen)
-    """
-    if activation not in [None, "silu", "swish"]:
-        raise NotImplementedError("activation must be None, silu, or swish")
-    dtype_in = x.dtype
-    x = x.to(weight.dtype)
-    seqlen = x.shape[-1]
-    dim, width = weight.shape
-    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
-    out = out[..., :seqlen]
-    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
-
-
-def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
-    """
-    x: (batch, dim)
-    conv_state: (batch, dim, width)
-    weight: (dim, width)
-    bias: (dim,)
-
-    out: (batch, dim)
-    """
-    if activation not in [None, "silu", "swish"]:
-        raise NotImplementedError("activation must be None, silu, or swish")
-    activation = activation in ["silu", "swish"]
-    return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation)
-
-
-def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
-    """
-    x: (batch, dim)
-    conv_state: (batch, dim, width)
-    weight: (dim, width)
-    bias: (dim,)
-
-    out: (batch, dim)
-    """
-    if activation not in [None, "silu", "swish"]:
-        raise NotImplementedError("activation must be None, silu, or swish")
-    dtype_in = x.dtype
-    batch, dim = x.shape
-    width = weight.shape[1]
-    assert conv_state.shape == (batch, dim, width)
-    assert weight.shape == (dim, width)
-    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
-    conv_state[:, :, -1] = x
-    out = torch.sum(conv_state * weight, dim=-1)  # (B D)
-    if bias is not None:
-        out += bias
-    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
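Comparing the two signatures above, activation moves from the fifth to the eighth positional parameter of causal_conv1d_fn in this release, so callers that passed it positionally against 1.1.3 need to switch to a keyword argument. A hypothetical before/after (tensor names are placeholders):

# 1.1.3.post1: activation was the fifth positional parameter.
out = causal_conv1d_fn(x, weight, bias, seq_idx, "silu")

# 1.2.0.post1: the fifth positional slot is now initial_states,
# so pass activation by keyword instead.
out = causal_conv1d_fn(x, weight, bias, seq_idx, activation="silu")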
The remaining seven files (AUTHORS, LICENSE, README.md, causal_conv1d.egg-info/SOURCES.txt, causal_conv1d.egg-info/dependency_links.txt, causal_conv1d.egg-info/top_level.txt, setup.cfg) are renamed with the version bump but contain no changes.