causal-conv1d 1.1.3.post1.tar.gz → 1.2.0.post1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: causal_conv1d
- Version: 1.1.3.post1
+ Version: 1.2.0.post1
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
  Home-page: https://github.com/Dao-AILab/causal-conv1d
  Author: Tri Dao
@@ -1,3 +1,3 @@
- __version__ = "1.1.3.post1"
+ __version__ = "1.2.0.post1"
  
  from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
@@ -0,0 +1,213 @@
+ # Copyright (c) 2024, Tri Dao.
+
+ import torch
+ import torch.nn.functional as F
+
+
+ import causal_conv1d_cuda
+
+
+ class CausalConv1dFn(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx,
+         x,
+         weight,
+         bias=None,
+         seq_idx=None,
+         initial_states=None,
+         return_final_states=False,
+         final_states_out=None,
+         activation=None,
+     ):
+         if activation not in [None, "silu", "swish"]:
+             raise NotImplementedError("activation must be None, silu, or swish")
+         if x.stride(2) != 1 and x.stride(1) != 1:
+             x = x.contiguous()
+         bias = bias.contiguous() if bias is not None else None
+         if seq_idx is not None:
+             assert (
+                 initial_states is None
+             ), "initial_states must be None if seq_idx is not None"
+             assert (
+                 not return_final_states
+             ), "If seq_idx is not None, we don't return final_states_out"
+         seq_idx = seq_idx.contiguous() if seq_idx is not None else None
+         if initial_states is not None and (
+             initial_states.stride(2) != 1 and initial_states.stride(1) != 1
+         ):
+             initial_states = initial_states.contiguous()
+         if return_final_states:
+             assert (
+                 x.stride(1) == 1
+             ), "Only channel-last layout support returning final_states_out"
+             if final_states_out is not None:
+                 assert (
+                     final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
+                 )
+             else:
+                 batch, dim, seqlen = x.shape
+                 width = weight.shape[1]
+                 final_states_out = torch.empty(
+                     batch, width - 1, dim, device=x.device, dtype=x.dtype
+                 ).transpose(1, 2)
+         else:
+             final_states_out = None
+         ctx.activation = activation in ["silu", "swish"]
+         out = causal_conv1d_cuda.causal_conv1d_fwd(
+             x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
+         )
+         ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
+         ctx.return_final_states = return_final_states
+         ctx.return_dinitial_states = (
+             initial_states is not None and initial_states.requires_grad
+         )
+         return out if not return_final_states else (out, final_states_out)
+
+     @staticmethod
+     def backward(ctx, dout, *args):
+         x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
+         dfinal_states = args[0] if ctx.return_final_states else None
+         if dout.stride(2) != 1 and dout.stride(1) != 1:
+             dout = dout.contiguous()
+         # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
+         # backward of conv1d with the backward of chunk).
+         # Here we just pass in None and dx will be allocated in the C++ code.
+         dx, dweight, dbias, dinitial_states = causal_conv1d_cuda.causal_conv1d_bwd(
+             x,
+             weight,
+             bias,
+             dout,
+             seq_idx,
+             initial_states,
+             dfinal_states,
+             None,
+             ctx.return_dinitial_states,
+             ctx.activation,
+         )
+         return (
+             dx,
+             dweight,
+             dbias if bias is not None else None,
+             None,
+             dinitial_states if initial_states is not None else None,
+             None,
+             None,
+             None,
+         )
+
+
+ def causal_conv1d_fn(
+     x,
+     weight,
+     bias=None,
+     seq_idx=None,
+     initial_states=None,
+     return_final_states=False,
+     final_states_out=None,
+     activation=None,
+ ):
+     """
+     x: (batch, dim, seqlen)
+     weight: (dim, width)
+     bias: (dim,)
+     seq_idx: (batch, seqlen)
+     initial_states: (batch, dim, width - 1)
+     final_states_out: (batch, dim, width - 1), to be written to
+     activation: either None or "silu" or "swish"
+
+     out: (batch, dim, seqlen)
+     """
+     return CausalConv1dFn.apply(
+         x,
+         weight,
+         bias,
+         seq_idx,
+         initial_states,
+         return_final_states,
+         final_states_out,
+         activation,
+     )
+
+
+ def causal_conv1d_ref(
+     x,
+     weight,
+     bias=None,
+     initial_states=None,
+     return_final_states=False,
+     final_states_out=None,
+     activation=None,
+ ):
+     """
+     x: (batch, dim, seqlen)
+     weight: (dim, width)
+     bias: (dim,)
+     initial_states: (batch, dim, width - 1)
+     final_states_out: (batch, dim, width - 1)
+
+     out: (batch, dim, seqlen)
+     """
+     if activation not in [None, "silu", "swish"]:
+         raise NotImplementedError("activation must be None, silu, or swish")
+     dtype_in = x.dtype
+     x = x.to(weight.dtype)
+     seqlen = x.shape[-1]
+     dim, width = weight.shape
+     if initial_states is None:
+         out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
+     else:
+         x = torch.cat([initial_states, x], dim=-1)
+         out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
+     out = out[..., :seqlen]
+     if return_final_states:
+         final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
+             dtype_in
+         )  # (batch, dim, width - 1)
+         if final_states_out is not None:
+             final_states_out.copy_(final_states)
+         else:
+             final_states_out = final_states
+     out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
+     return out if not return_final_states else (out, final_states_out)
+
+
+ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
+     """
+     x: (batch, dim)
+     conv_state: (batch, dim, width)
+     weight: (dim, width)
+     bias: (dim,)
+
+     out: (batch, dim)
+     """
+     if activation not in [None, "silu", "swish"]:
+         raise NotImplementedError("activation must be None, silu, or swish")
+     activation = activation in ["silu", "swish"]
+     return causal_conv1d_cuda.causal_conv1d_update(
+         x, conv_state, weight, bias, activation
+     )
+
+
+ def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
+     """
+     x: (batch, dim)
+     conv_state: (batch, dim, width)
+     weight: (dim, width)
+     bias: (dim,)
+
+     out: (batch, dim)
+     """
+     if activation not in [None, "silu", "swish"]:
+         raise NotImplementedError("activation must be None, silu, or swish")
+     dtype_in = x.dtype
+     batch, dim = x.shape
+     width = weight.shape[1]
+     assert conv_state.shape == (batch, dim, width)
+     assert weight.shape == (dim, width)
+     conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
+     conv_state[:, :, -1] = x
+     out = torch.sum(conv_state * weight, dim=-1)  # (B D)
+     if bias is not None:
+         out += bias
+     return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
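
For readers of this diff, here is a minimal usage sketch (not part of the package) of the interface added in 1.2.0.post1: the new initial_states / return_final_states arguments let a long sequence be processed in chunks while carrying convolution state across chunk boundaries. It assumes the compiled causal_conv1d_cuda extension is importable and a CUDA device is available; shapes follow the docstrings above, and return_final_states requires the channels-last layout (x.stride(1) == 1). The result is cross-checked against the pure-PyTorch causal_conv1d_ref helper defined in the same module.

import torch
from causal_conv1d import causal_conv1d_fn
from causal_conv1d.causal_conv1d_interface import causal_conv1d_ref

batch, dim, seqlen, width = 2, 64, 128, 4
device, dtype = "cuda", torch.float32

# Channels-last layout: allocate (batch, seqlen, dim) and transpose so x.stride(1) == 1.
x = torch.randn(batch, seqlen, dim, device=device, dtype=dtype).transpose(1, 2)
weight = torch.randn(dim, width, device=device, dtype=dtype)
bias = torch.randn(dim, device=device, dtype=dtype)
# State carried in from a previous chunk: (batch, dim, width - 1), also channels-last.
initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=dtype).transpose(1, 2)

out, final_states = causal_conv1d_fn(
    x, weight, bias,
    initial_states=initial_states,
    return_final_states=True,
    activation="silu",
)
# Reference path for comparison.
out_ref, final_states_ref = causal_conv1d_ref(
    x, weight, bias,
    initial_states=initial_states,
    return_final_states=True,
    activation="silu",
)
print(torch.allclose(out, out_ref, atol=1e-3),
      torch.allclose(final_states, final_states_ref, atol=1e-3))
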
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: causal-conv1d
- Version: 1.1.3.post1
+ Version: 1.2.0.post1
  Summary: Causal depthwise conv1d in CUDA, with a PyTorch interface
  Home-page: https://github.com/Dao-AILab/causal-conv1d
  Author: Tri Dao
@@ -256,7 +256,6 @@ setup(
      install_requires=[
          "torch",
          "packaging",
-         "buildtools",
          "ninja",
      ],
  )
@@ -1,107 +0,0 @@
- # Copyright (c) 2023, Tri Dao.
-
- import torch
- import torch.nn.functional as F
-
-
- import causal_conv1d_cuda
-
-
- class CausalConv1dFn(torch.autograd.Function):
-
-     @staticmethod
-     def forward(ctx, x, weight, bias=None, seq_idx=None, activation=None):
-         if activation not in [None, "silu", "swish"]:
-             raise NotImplementedError("activation must be None, silu, or swish")
-         if x.stride(2) != 1 and x.stride(1) != 1:
-             x = x.contiguous()
-         bias = bias.contiguous() if bias is not None else None
-         seq_idx = seq_idx.contiguous() if seq_idx is not None else None
-         ctx.save_for_backward(x, weight, bias, seq_idx)
-         ctx.activation = activation in ["silu", "swish"]
-         out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, seq_idx, ctx.activation)
-         return out
-
-     @staticmethod
-     def backward(ctx, dout):
-         x, weight, bias, seq_idx = ctx.saved_tensors
-         if dout.stride(2) != 1 and dout.stride(1) != 1:
-             dout = dout.contiguous()
-         # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
-         # backward of conv1d with the backward of chunk).
-         # Here we just pass in None and dx will be allocated in the C++ code.
-         dx, dweight, dbias = causal_conv1d_cuda.causal_conv1d_bwd(
-             x, weight, bias, dout, seq_idx, None, ctx.activation
-         )
-         return dx, dweight, dbias if bias is not None else None, None, None
-
-
- def causal_conv1d_fn(x, weight, bias=None, seq_idx=None, activation=None):
-     """
-     x: (batch, dim, seqlen)
-     weight: (dim, width)
-     bias: (dim,)
-     seq_idx: (batch, seqlen)
-     activation: either None or "silu" or "swish"
-
-     out: (batch, dim, seqlen)
-     """
-     return CausalConv1dFn.apply(x, weight, bias, seq_idx, activation)
-
-
- def causal_conv1d_ref(x, weight, bias=None, activation=None):
-     """
-     x: (batch, dim, seqlen)
-     weight: (dim, width)
-     bias: (dim,)
-
-     out: (batch, dim, seqlen)
-     """
-     if activation not in [None, "silu", "swish"]:
-         raise NotImplementedError("activation must be None, silu, or swish")
-     dtype_in = x.dtype
-     x = x.to(weight.dtype)
-     seqlen = x.shape[-1]
-     dim, width = weight.shape
-     out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
-     out = out[..., :seqlen]
-     return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
-
-
- def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None):
-     """
-     x: (batch, dim)
-     conv_state: (batch, dim, width)
-     weight: (dim, width)
-     bias: (dim,)
-
-     out: (batch, dim)
-     """
-     if activation not in [None, "silu", "swish"]:
-         raise NotImplementedError("activation must be None, silu, or swish")
-     activation = activation in ["silu", "swish"]
-     return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation)
-
-
- def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None):
-     """
-     x: (batch, dim)
-     conv_state: (batch, dim, width)
-     weight: (dim, width)
-     bias: (dim,)
-
-     out: (batch, dim)
-     """
-     if activation not in [None, "silu", "swish"]:
-         raise NotImplementedError("activation must be None, silu, or swish")
-     dtype_in = x.dtype
-     batch, dim = x.shape
-     width = weight.shape[1]
-     assert conv_state.shape == (batch, dim, width)
-     assert weight.shape == (dim, width)
-     conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
-     conv_state[:, :, -1] = x
-     out = torch.sum(conv_state * weight, dim=-1)  # (B D)
-     if bias is not None:
-         out += bias
-     return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
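
One practical note that follows from comparing the two signatures above (an observation from this diff, not official migration guidance): in 1.2.0.post1 the new parameters initial_states, return_final_states, and final_states_out are inserted between seq_idx and activation, so a 1.1.3-style call that passed activation as the fifth positional argument would now bind that value to initial_states. Passing activation by keyword behaves the same on both versions; a short sketch, again assuming the compiled CUDA extension and a CUDA device:

import torch
from causal_conv1d import causal_conv1d_fn

x = torch.randn(2, 64, 128, device="cuda")   # (batch, dim, seqlen)
weight = torch.randn(64, 4, device="cuda")   # (dim, width)
bias = torch.randn(64, device="cuda")        # (dim,)

# 1.1.3.post1-style positional call -- on 1.2.0.post1 the string "silu" would
# be taken as initial_states and raise an error:
# out = causal_conv1d_fn(x, weight, bias, None, "silu")

# Keyword form, valid on both 1.1.3.post1 and 1.2.0.post1:
out = causal_conv1d_fn(x, weight, bias, activation="silu")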