jaxonlayers 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jaxonlayers-0.1.0.dist-info → jaxonlayers-0.1.1.dist-info}/METADATA +2 -2
- jaxonlayers-0.1.1.dist-info/RECORD +5 -0
- jaxonlayers/functions/__init__.py +0 -27
- jaxonlayers/functions/attention.py +0 -374
- jaxonlayers/functions/initialization.py +0 -34
- jaxonlayers/functions/masking.py +0 -58
- jaxonlayers/functions/normalization.py +0 -9
- jaxonlayers/functions/regularization.py +0 -47
- jaxonlayers/functions/state_space.py +0 -30
- jaxonlayers/functions/utils.py +0 -9
- jaxonlayers/layers/__init__.py +0 -16
- jaxonlayers/layers/abstract.py +0 -16
- jaxonlayers/layers/attention.py +0 -237
- jaxonlayers/layers/convolution.py +0 -100
- jaxonlayers/layers/normalization.py +0 -246
- jaxonlayers/layers/regularization.py +0 -20
- jaxonlayers/layers/sequential.py +0 -79
- jaxonlayers/layers/state_space.py +0 -75
- jaxonlayers-0.1.0.dist-info/RECORD +0 -21
- {jaxonlayers-0.1.0.dist-info → jaxonlayers-0.1.1.dist-info}/WHEEL +0 -0
- {jaxonlayers-0.1.0.dist-info → jaxonlayers-0.1.1.dist-info}/top_level.txt +0 -0

{jaxonlayers-0.1.0.dist-info → jaxonlayers-0.1.1.dist-info}/METADATA
@@ -1,10 +1,10 @@
 Metadata-Version: 2.4
 Name: jaxonlayers
-Version: 0.1.0
+Version: 0.1.1
 Summary: Add your description here
 Requires-Python: >=3.13
 Description-Content-Type: text/markdown
 Requires-Dist: beartype>=0.21.0
 Requires-Dist: equinox>=0.13.0
-Requires-Dist: jax>=0.
+Requires-Dist: jax>=0.8.0
 Requires-Dist: jaxtyping>=0.3.2

jaxonlayers-0.1.1.dist-info/RECORD
ADDED
@@ -0,0 +1,5 @@
+jaxonlayers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+jaxonlayers-0.1.1.dist-info/METADATA,sha256=aLnCVoHFmkHYSU5QiDRhYhuYBN4IcZDXAmpuYm-asfk,275
+jaxonlayers-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+jaxonlayers-0.1.1.dist-info/top_level.txt,sha256=n5UHFDErh3dJY77ypkEKlwFOQffKBnpGH9nUrwCinto,12
+jaxonlayers-0.1.1.dist-info/RECORD,,

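The new RECORD shows that the 0.1.1 wheel ships only an empty `jaxonlayers/__init__.py` (size 0) plus packaging metadata, so every module deleted below is absent from an installed 0.1.1, not merely relocated. A minimal, hypothetical guard that downstream code could use to fail with a clear message instead of a bare ImportError (names taken from the 0.1.0 file list above):

try:
    # Public 0.1.0 entry points; none of these subpackages exist in 0.1.1.
    from jaxonlayers.functions import dropout, stochastic_depth
    from jaxonlayers.layers import LayerNorm, MultiheadAttention
except ImportError as err:
    raise ImportError(
        "This code depends on the jaxonlayers 0.1.0 API, which was removed in 0.1.1; "
        "pin jaxonlayers==0.1.0 if these modules are required."
    ) from err
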
jaxonlayers/functions/__init__.py
DELETED
@@ -1,27 +0,0 @@
-from .attention import multi_head_attention_forward, shifted_window_attention
-from .initialization import kaiming_init_conv2d
-from .masking import (
-    build_attention_mask,
-    canonical_attn_mask,
-    canonical_key_padding_mask,
-    canonical_mask,
-)
-from .regularization import dropout, stochastic_depth
-from .state_space import selective_scan
-from .utils import (
-    default_floating_dtype,
-)
-
-__all__ = [
-    "multi_head_attention_forward",
-    "kaiming_init_conv2d",
-    "build_attention_mask",
-    "canonical_attn_mask",
-    "canonical_key_padding_mask",
-    "canonical_mask",
-    "stochastic_depth",
-    "selective_scan",
-    "dropout",
-    "default_floating_dtype",
-    "shifted_window_attention",
-]

jaxonlayers/functions/attention.py
DELETED
@@ -1,374 +0,0 @@
-import functools
-
-import equinox as eqx
-import jax
-import jax.numpy as jnp
-from beartype.typing import Any
-from jaxtyping import Array, Bool, Float, PRNGKeyArray
-
-from jaxonlayers.functions.normalization import normalize
-from jaxonlayers.functions.regularization import dropout as dropout_fn
-from jaxonlayers.functions.utils import default_floating_dtype
-
-
-def multi_head_attention_forward(
-    query: Float[Array, "tgt_len d_model"],
-    key: Float[Array, "src_len d_model"],
-    value: Float[Array, "src_len d_model"],
-    embed_dim_to_check: int,
-    num_heads: int,
-    in_proj_weight: Float[Array, "3*d_model d_model"] | None = None,
-    in_proj_bias: Float[Array, "3*d_model"] | None = None,
-    bias_k: Float[Array, "1 d_model"] | None = None,
-    bias_v: Float[Array, "1 d_model"] | None = None,
-    add_zero_attn: bool = False,
-    dropout_p: float = 0.0,
-    out_proj_weight: Float[Array, "d_model d_model"] | None = None,
-    out_proj_bias: Float[Array, "d_model"] | None = None,
-    inference: bool = False,
-    key_padding_mask: Float[Array, "src_len"] | Bool[Array, "src_len"] | None = None,
-    attn_mask: Float[Array, "tgt_len src_len"] | None = None,
-    need_weights: bool = True,
-    use_separate_proj_weight: bool = False,
-    q_proj_weight: Float[Array, "d_model d_model"] | None = None,
-    k_proj_weight: Float[Array, "d_model d_model"] | None = None,
-    v_proj_weight: Float[Array, "d_model d_model"] | None = None,
-    static_k: Float[Array, "src_len d_model"] | None = None,
-    static_v: Float[Array, "src_len d_model"] | None = None,
-    average_attn_weights: bool = True,
-    is_causal: bool = False,
-    dropout_key: PRNGKeyArray | None = None,
-) -> tuple[
-    Float[Array, "tgt_len d_model"],
-    Float[Array, "num_heads tgt_len src_len"]
-    | Float[Array, "tgt_len src_len"]
-    | Float[Array, "tgt_len src_len+1"]
-    | None,
-]:
-    tgt_len, d_model = query.shape
-    src_len, k_dim = key.shape
-    value_len, v_dim = value.shape
-
-    assert d_model == k_dim == v_dim == embed_dim_to_check, (
-        "Embedding dimensions must match"
-    )
-
-    assert src_len == value_len, "Key and value must have the same sequence length"
-
-    head_dim = d_model // num_heads
-    assert head_dim * num_heads == d_model, "embed_dim must be divisible by num_heads"
-
-    if dropout_p > 0.0:
-        assert dropout_key is not None, (
-            "dropout_key must be provided if dropout_p > 0.0"
-        )
-
-    if use_separate_proj_weight:
-        # When using separate projection weights for q, k, v
-        assert q_proj_weight is not None, (
-            "q_proj_weight should not be None when use_separate_proj_weight=True"
-        )
-        assert k_proj_weight is not None, (
-            "k_proj_weight should not be None when use_separate_proj_weight=True"
-        )
-        assert v_proj_weight is not None, (
-            "v_proj_weight should not be None when use_separate_proj_weight=True"
-        )
-
-        q = query @ q_proj_weight.T
-
-        if static_k is None:
-            k = key @ k_proj_weight.T
-        else:
-            k = static_k
-            src_len, _ = k.shape
-
-        if static_v is None:
-            v = value @ v_proj_weight.T
-        else:
-            v = static_v
-            value_len, _ = v.shape
-
-        if in_proj_bias is not None:
-            q_bias, k_bias, v_bias = jnp.split(in_proj_bias, 3)
-            q = q + q_bias
-            k = k + k_bias
-            v = v + v_bias
-
-    else:
-        assert in_proj_weight is not None, (
-            "in_proj_weight should not be None when use_separate_proj_weight=False"
-        )
-
-        q_proj_weight_part, k_proj_weight_part, v_proj_weight_part = jnp.split(
-            in_proj_weight, 3
-        )
-
-        q = query @ q_proj_weight_part.T
-
-        if static_k is None:
-            k = key @ k_proj_weight_part.T
-        else:
-            k = static_k
-            src_len, _ = static_k.shape
-
-        if static_v is None:
-            v = value @ v_proj_weight_part.T
-        else:
-            v = static_v
-            value_len, _ = static_v.shape
-
-        if in_proj_bias is not None:
-            q_bias, k_bias, v_bias = jnp.split(in_proj_bias, 3)
-            q = q + q_bias
-            k = k + k_bias
-            v = v + v_bias
-
-    assert src_len == value_len
-
-    q = q.reshape(tgt_len, num_heads, head_dim)
-    k = k.reshape(src_len, num_heads, head_dim)
-    v = v.reshape(src_len, num_heads, head_dim)
-
-    if add_zero_attn:
-        zero_attn_shape = (1, num_heads, head_dim)
-        k_zeros = jnp.zeros(zero_attn_shape)
-        v_zeros = jnp.zeros(zero_attn_shape)
-
-        k = jnp.concatenate([k, k_zeros], axis=0)
-        v = jnp.concatenate([v, v_zeros], axis=0)
-
-        src_len += 1
-        value_len += 1
-
-    if bias_k is not None and bias_v is not None:
-        bias_k = bias_k.reshape(1, num_heads, head_dim)
-        bias_v = bias_v.reshape(1, num_heads, head_dim)
-
-        k = jnp.concatenate([k, bias_k], axis=0)
-        v = jnp.concatenate([v, bias_v], axis=0)
-
-        src_len += 1
-        value_len += 1
-
-    assert src_len == value_len
-
-    # [tgt_len, num_heads, head_dim] → [num_heads, tgt_len, head_dim]
-    q = jnp.transpose(q, (1, 0, 2))
-
-    # [src_len, num_heads, head_dim] → [num_heads, src_len, head_dim]
-    k = jnp.transpose(k, (1, 0, 2))
-    v = jnp.transpose(v, (1, 0, 2))
-
-    scale = jnp.sqrt(head_dim)
-    attn_output_weights = jnp.matmul(q, jnp.transpose(k, (0, 2, 1))) / scale
-
-    if key_padding_mask is not None:
-        padding_mask = key_padding_mask.reshape(1, 1, src_len)
-        padding_mask = jnp.repeat(padding_mask, num_heads, axis=0)
-        padding_mask = jnp.repeat(padding_mask, tgt_len, axis=1)
-        attn_output_weights = jnp.where(
-            padding_mask, float("-inf"), attn_output_weights
-        )
-
-    if attn_mask is not None:
-        # [tgt_len, src_len] -> [num_heads, tgt_len, src_len]
-        mask = attn_mask.reshape(1, tgt_len, src_len)
-        mask = jnp.repeat(mask, num_heads, axis=0)
-        attn_output_weights = attn_output_weights + mask
-
-    if is_causal:
-        causal_mask = jnp.triu(jnp.ones((tgt_len, src_len)), k=1)
-        causal_mask = (causal_mask == 1).reshape(1, tgt_len, src_len)
-        causal_mask = jnp.repeat(causal_mask, num_heads, axis=0)
-        attn_output_weights = jnp.where(causal_mask, float("-inf"), attn_output_weights)
-
-    # [num_heads, tgt_len, src_len]
-    attn_output_weights = jax.nn.softmax(attn_output_weights, axis=-1)
-
-    if dropout_p > 0.0 and not inference:
-        assert dropout_key is not None, (
-            "dropout_key required because dropout_p > 0.0 and training"
-        )
-        dropout_mask = jax.random.bernoulli(
-            dropout_key, 1 - dropout_p, attn_output_weights.shape
-        )
-        scale = 1.0 / (1.0 - dropout_p)
-        attn_output_weights = attn_output_weights * dropout_mask * scale
-
-    attn_output = jnp.matmul(attn_output_weights, v)
-    attn_output = jnp.transpose(attn_output, (1, 0, 2))
-    attn_output = attn_output.reshape(tgt_len, d_model)
-
-    assert out_proj_weight is not None, "out_proj_weight must be provided"
-    attn_output = attn_output @ out_proj_weight.T
-
-    if out_proj_bias is not None:
-        attn_output = attn_output + out_proj_bias
-
-    if need_weights:
-        if average_attn_weights:
-            attn_output_weights = attn_output_weights.mean(axis=0)
-        return attn_output, attn_output_weights
-    else:
-        return attn_output, None
-
-
-def create_attn_mask(pad_H, pad_W, window_size, shift_size, dtype: Any | None = None):
-    if dtype is None:
-        dtype = default_floating_dtype()
-    assert dtype is not None
-    h_boundaries = jnp.array([pad_H - window_size[0], pad_H - shift_size[0]])
-    w_boundaries = jnp.array([pad_W - window_size[1], pad_W - shift_size[1]])
-
-    h_boundaries = jnp.sort(h_boundaries)
-    w_boundaries = jnp.sort(w_boundaries)
-
-    ii, jj = jnp.indices((pad_H, pad_W))  # ii for rows, jj for columns
-
-    row_region_idx = jnp.searchsorted(h_boundaries, ii, side="right")
-    col_region_idx = jnp.searchsorted(w_boundaries, jj, side="right")
-
-    num_col_regions = len(w_boundaries) + 1
-    attn_mask = row_region_idx * num_col_regions + col_region_idx
-
-    return attn_mask.astype(dtype)
-
-
-def shifted_window_attention(
-    x: Float[Array, "H W C"],
-    qkv_weight: Float[Array, "in_dim out_dim"],
-    proj_weight: Float[Array, "out_dim out_dim"],
-    relative_position_bias: Array,
-    window_size: list[int],
-    num_heads: int,
-    shift_size: list[int],
-    attention_dropout: float = 0.0,
-    dropout: float = 0.0,
-    qkv_bias: Array | None = None,
-    proj_bias: Array | None = None,
-    logit_scale: Array | None = None,
-    inference: bool = False,
-    key: PRNGKeyArray | None = None,
-) -> Float[Array, "H W C"]:
-    if not inference and key is None:
-        raise ValueError("Need key when in training mode")
-    H, W, C = x.shape
-    to_pad_W = (window_size[1] - W % window_size[1]) % window_size[1]
-    to_pad_H = (window_size[0] - H % window_size[0]) % window_size[0]
-    x = jnp.pad(x, ((0, to_pad_H), (0, to_pad_W), (0, 0)))
-    pad_H, pad_W, _ = x.shape
-
-    shift_size = shift_size.copy()
-    if window_size[0] >= pad_H:
-        shift_size[0] = 0
-    if window_size[1] >= pad_W:
-        shift_size[1] = 0
-
-    # cyclic shift
-    if sum(shift_size) > 0:
-        x = jnp.roll(x, shift=(-shift_size[0], -shift_size[1]), axis=(0, 1))
-
-    # partition windows
-    num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
-    x = jnp.reshape(
-        x,
-        (
-            pad_H // window_size[0],
-            window_size[0],
-            pad_W // window_size[1],
-            window_size[1],
-            C,
-        ),
-    )
-    x = jnp.transpose(x, (0, 2, 1, 3, 4)).reshape(
-        num_windows, window_size[0] * window_size[1], C
-    )
-
-    # multi-head attention
-    if logit_scale is not None and qkv_bias is not None:
-        length = qkv_bias.size // 3
-        qkv_bias = qkv_bias.at[length : 2 * length].set(0.0)
-
-    def linear(x: Array, weight: Array, bias: Array | None):
-        output = x @ jnp.transpose(weight)  # (in,) @ (in, out) -> (out,)
-        if bias is not None:
-            output = output + bias
-        return output
-
-    linear_pt = functools.partial(linear, weight=qkv_weight, bias=qkv_bias)
-
-    qkv = eqx.filter_vmap(eqx.filter_vmap(linear_pt))(x)
-    win_size, patches, _ = qkv.shape
-    qkv = jnp.transpose(
-        qkv.reshape(win_size, patches, 3, num_heads, C // num_heads), (2, 0, 3, 1, 4)
-    )
-    q, k, v = qkv[0], qkv[1], qkv[2]
-    if logit_scale is not None:
-        # cosine attention
-        attn = normalize(q, axis=-1) @ jnp.transpose(
-            normalize(k, axis=-1), (0, 1, 3, 2)
-        )
-        # Clamp the logit scale exponent for stability
-        logit_scale = jnp.exp(jnp.minimum(logit_scale, jnp.log(jnp.array(100.0))))
-        attn = attn * logit_scale
-    else:
-        q = q * (C // num_heads) ** -0.5
-        # attn = q @ (jnp.transpose(normalize(k, axis=-1), (0, 1, 3, 2)))  # Incorrect
-        attn = q @ jnp.transpose(k, (0, 1, 3, 2))  # Corrected: q @ k.T
-
-    # add relative position bias
-    attn = attn + relative_position_bias
-
-    if sum(shift_size) > 0:
-        attn_mask = create_attn_mask(pad_H, pad_W, window_size, shift_size)
-        attn_mask = attn_mask.reshape(
-            pad_H // window_size[0],
-            window_size[0],
-            pad_W // window_size[1],
-            window_size[1],
-        )
-        attn_mask = jnp.transpose(attn_mask, (0, 2, 1, 3)).reshape(
-            num_windows, window_size[0] * window_size[1]
-        )
-        attn_mask = jnp.expand_dims(attn_mask, axis=1) - jnp.expand_dims(
-            attn_mask, axis=2
-        )
-        attn_mask = jnp.where(attn_mask == 0, 0.0, -100.0)
-
-        attn = attn + attn_mask[:, None, :, :]
-
-    attn = jax.nn.softmax(attn, axis=-1)
-    if not inference:
-        assert key is not None, "key must be given if not inference"
-        key, subkey = jax.random.split(key)
-        attn = dropout_fn(attn, p=attention_dropout, inference=inference, key=subkey)
-
-    x = jnp.transpose(attn @ v, (0, 2, 1, 3)).reshape(
-        num_windows, window_size[0] * window_size[1], C
-    )
-    linear_pt_proj = functools.partial(linear, weight=proj_weight, bias=proj_bias)
-
-    x = eqx.filter_vmap(eqx.filter_vmap(linear_pt_proj))(x)
-    if not inference:
-        assert key is not None, "key must be given if not inference"
-        key, subkey = jax.random.split(key)
-        x = dropout_fn(x, p=dropout, inference=inference, key=subkey)
-
-    # reverse windows
-    x = x.reshape(
-        pad_H // window_size[0],
-        pad_W // window_size[1],
-        window_size[0],
-        window_size[1],
-        C,
-    )
-    x = jnp.transpose(x, (0, 2, 1, 3, 4)).reshape(pad_H, pad_W, C)
-
-    # reverse cyclic shift
-    if sum(shift_size) > 0:
-        x = jnp.roll(x, shift=(shift_size[0], shift_size[1]), axis=(0, 1))
-
-    # unpad features
-    x = x[:H, :W, :]
-    return x

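For reference, a minimal usage sketch of the deleted `multi_head_attention_forward`, written against the 0.1.0 signature shown above. It assumes jaxonlayers 0.1.0 is installed; the dimensions and randomly initialised weights are illustrative only, not part of the package.

import jax
import jax.numpy as jnp

# Available in jaxonlayers 0.1.0 only; removed in 0.1.1.
from jaxonlayers.functions.attention import multi_head_attention_forward

d_model, num_heads, tgt_len, src_len = 16, 4, 5, 7
keys = jax.random.split(jax.random.PRNGKey(0), 5)

query = jax.random.normal(keys[0], (tgt_len, d_model))
key_ = jax.random.normal(keys[1], (src_len, d_model))
value = jax.random.normal(keys[2], (src_len, d_model))

# Packed q/k/v projection, as expected when use_separate_proj_weight=False (the default).
in_proj_weight = jax.random.normal(keys[3], (3 * d_model, d_model)) * 0.02
in_proj_bias = jnp.zeros(3 * d_model)
out_proj_weight = jax.random.normal(keys[4], (d_model, d_model)) * 0.02
out_proj_bias = jnp.zeros(d_model)

attn_out, attn_weights = multi_head_attention_forward(
    query,
    key_,
    value,
    embed_dim_to_check=d_model,
    num_heads=num_heads,
    in_proj_weight=in_proj_weight,
    in_proj_bias=in_proj_bias,
    out_proj_weight=out_proj_weight,
    out_proj_bias=out_proj_bias,
    inference=True,
)

print(attn_out.shape)      # (5, 16)
print(attn_weights.shape)  # (5, 7), averaged over heads by default
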
jaxonlayers/functions/initialization.py
DELETED
@@ -1,34 +0,0 @@
-import equinox as eqx
-import jax
-from jaxtyping import PRNGKeyArray, PyTree
-
-
-def kaiming_init_conv2d(model: PyTree, state: eqx.nn.State, key: PRNGKeyArray):
-    """Applies Kaiming He normal initialization to Conv2d weights."""
-    # Filter function to identify Conv2d layers
-    is_conv2d = lambda x: isinstance(x, eqx.nn.Conv2d)
-
-    # Function to get weights (leaves) based on the filter
-    def get_weights(model):
-        return [
-            x.weight for x in jax.tree.leaves(model, is_leaf=is_conv2d) if is_conv2d(x)
-        ]
-
-    # Get the list of current weights
-    weights = get_weights(model)
-    if not weights:  # If no Conv2d layers found
-        return model, state
-
-    # Create new weights using He initializer
-    initializer = jax.nn.initializers.he_normal()
-    # Split key for each weight tensor
-    subkeys = jax.random.split(key, len(weights))
-    new_weights = [
-        initializer(subkeys[i], w.shape, w.dtype)  # Use original weight's dtype
-        for i, w in enumerate(weights)
-    ]
-
-    # Replace old weights with new weights in the model pytree
-    model = eqx.tree_at(get_weights, model, new_weights)
-
-    return model, state

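A short sketch of how the deleted `kaiming_init_conv2d` was intended to be used: it rewrites every `eqx.nn.Conv2d` weight in a model PyTree with He-normal initialisation and passes the state through unchanged. The `SmallCNN` model here is a hypothetical stand-in, and the snippet assumes jaxonlayers 0.1.0 is installed.

import equinox as eqx
import jax

from jaxonlayers.functions.initialization import kaiming_init_conv2d  # 0.1.0 only


class SmallCNN(eqx.Module):
    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.conv1 = eqx.nn.Conv2d(3, 16, kernel_size=3, key=k1)
        self.conv2 = eqx.nn.Conv2d(16, 32, kernel_size=3, key=k2)


# Build the model; there are no stateful layers here, so the state is empty
# but is still threaded through to match the helper's signature.
model, state = eqx.nn.make_with_state(SmallCNN)(jax.random.PRNGKey(0))

# Re-initialise every Conv2d weight with He-normal initialisation.
model, state = kaiming_init_conv2d(model, state, jax.random.PRNGKey(1))
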
jaxonlayers/functions/masking.py
DELETED
@@ -1,58 +0,0 @@
-import jax.numpy as jnp
-
-
-def canonical_mask(
-    mask,
-    mask_name,
-    other_name="",
-    other_type=None,
-    target_type=jnp.float32,
-    other_mask=None,
-    check_other=True,
-):
-    if mask is None:
-        return None
-    if mask.dtype == bool:
-        additive_mask = jnp.where(mask, -jnp.inf, 0.0).astype(target_type)
-        return additive_mask
-    elif jnp.issubdtype(mask.dtype, jnp.integer) or jnp.issubdtype(
-        mask.dtype, jnp.floating
-    ):
-        return mask.astype(target_type)
-    else:
-        raise TypeError(
-            f"{mask_name} must be bool, int, or float tensor, but got {mask.dtype}"
-        )
-
-
-def canonical_key_padding_mask(
-    key_padding_mask, attn_mask=None, query_dtype=jnp.float32
-):
-    """Wrapper for canonicalizing key_padding_mask"""
-    return canonical_mask(
-        mask=key_padding_mask,
-        mask_name="key_padding_mask",
-        other_name="attn_mask",
-        other_mask=attn_mask,
-        target_type=query_dtype,
-    )
-
-
-def canonical_attn_mask(attn_mask, query_dtype=jnp.float32):
-    """Wrapper for canonicalizing attn_mask"""
-    return canonical_mask(
-        mask=attn_mask,
-        mask_name="attn_mask",
-        other_type=None,
-        other_name="",
-        target_type=query_dtype,
-        check_other=False,
-    )
-
-
-def build_attention_mask(context_length: int):
-    mask = jnp.tril(jnp.zeros((context_length, context_length)))
-    upper = jnp.triu(jnp.full((context_length, context_length), float("-inf")), k=1)
-
-    mask = mask + upper
-    return mask

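The deleted `build_attention_mask` produced a standard additive causal mask, zeros on and below the diagonal and -inf above it, which gets added to attention logits before the softmax. A tiny worked example, assuming 0.1.0 is installed:

from jaxonlayers.functions.masking import build_attention_mask  # 0.1.0 only

mask = build_attention_mask(3)
print(mask)
# [[  0. -inf -inf]
#  [  0.   0. -inf]
#  [  0.   0.   0.]]
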
jaxonlayers/functions/regularization.py
DELETED
@@ -1,47 +0,0 @@
-import jax
-import jax.numpy as jnp
-from jaxtyping import Array, PRNGKeyArray
-
-
-def stochastic_depth(
-    input: Array,
-    p: float,
-    mode: str,
-    inference: bool,
-    key: PRNGKeyArray,
-) -> Array:
-    if p < 0.0 or p > 1.0:
-        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
-    if mode not in ["batch", "row"]:
-        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
-    if inference or p == 0.0:
-        return input
-    survival_rate = 1.0 - p
-    if mode == "row":
-        size = [input.shape[0]] + [1] * (input.ndim - 1)
-    else:
-        size = [1] * input.ndim
-    noise = jax.random.bernoulli(key, p=survival_rate, shape=size).astype(input.dtype)
-    if survival_rate > 0.0:
-        noise = noise / survival_rate
-    return input * noise
-
-
-def dropout(
-    x: Array,
-    p: float,
-    inference: bool,
-    key: PRNGKeyArray | None = None,
-) -> Array:
-    if isinstance(p, (int, float)) and p == 0:
-        inference = True
-    if inference:
-        return x
-    elif key is None:
-        raise RuntimeError(
-            "Dropout requires a key when running in non-deterministic mode."
-        )
-    else:
-        q = 1 - jax.lax.stop_gradient(p)
-        mask = jax.random.bernoulli(key, q, x.shape)
-        return jnp.where(mask, x / q, 0)

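Both deleted regularization helpers use inverted scaling: surviving elements (for `dropout`) or surviving rows (for `stochastic_depth` in "row" mode) are divided by the keep probability so the expected activation is unchanged, and both act as the identity at inference time. A minimal sketch against the 0.1.0 signatures above:

import jax
import jax.numpy as jnp

from jaxonlayers.functions.regularization import dropout, stochastic_depth  # 0.1.0 only

key = jax.random.PRNGKey(0)
x = jnp.ones((4, 8))

# Element-wise dropout: kept entries are scaled by 1 / (1 - p).
y = dropout(x, p=0.5, inference=False, key=key)

# Row-mode stochastic depth: each row (e.g. one residual branch per example)
# is either zeroed or scaled by 1 / (1 - p) as a whole.
z = stochastic_depth(x, p=0.2, mode="row", inference=False, key=key)

# At inference time both are identity functions.
assert jnp.array_equal(dropout(x, p=0.5, inference=True), x)
assert jnp.array_equal(stochastic_depth(x, p=0.2, mode="row", inference=True, key=key), x)
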
jaxonlayers/functions/state_space.py
DELETED
@@ -1,30 +0,0 @@
-import jax
-import jax.numpy as jnp
-from jaxtyping import Array, Float
-
-
-def selective_scan(
-    x: Float[Array, "seq_length d_inner"],
-    delta: Float[Array, "seq_length d_inner"],
-    A: Float[Array, "d_inner d_state"],
-    B: Float[Array, "seq_length d_state"],
-    C: Float[Array, "seq_length d_state"],
-    D: Float[Array, " d_inner"],
-) -> Float[Array, "seq_length d_inner"]:
-    L, d_inner = x.shape
-    _, d_state = A.shape
-    delta_A = jnp.exp(jnp.einsum("l d,d n -> l d n", delta, A))
-    delta_B_u = jnp.einsum("l d,l n,l d -> l d n", delta, B, x)
-
-    x_res = jnp.zeros(shape=(d_inner, d_state))
-
-    def step(x, i):
-        x = delta_A[i] * x + delta_B_u[i]
-
-        y = jnp.einsum("d n,n -> d", x, C[i, :])
-        return x, y
-
-    _, ys = jax.lax.scan(step, x_res, jnp.arange(L))
-
-    ys = ys + x * D
-    return ys

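The deleted `selective_scan` runs a Mamba-style selective state-space recurrence with `jax.lax.scan`: at each step it updates a hidden state h_t = exp(delta_t * A) * h_{t-1} + (delta_t * B_t) * x_t and reads out y_t = C_t · h_t + D * x_t. A shape-level usage sketch, assuming 0.1.0 is installed and with an arbitrary illustrative parameterisation:

import jax
import jax.numpy as jnp

from jaxonlayers.functions.state_space import selective_scan  # 0.1.0 only

seq_length, d_inner, d_state = 32, 8, 4
keys = jax.random.split(jax.random.PRNGKey(0), 5)

x = jax.random.normal(keys[0], (seq_length, d_inner))
delta = jax.nn.softplus(jax.random.normal(keys[1], (seq_length, d_inner)))  # positive step sizes
A = -jnp.exp(jax.random.normal(keys[2], (d_inner, d_state)))                # negative entries, a common stability choice
B = jax.random.normal(keys[3], (seq_length, d_state))
C = jax.random.normal(keys[4], (seq_length, d_state))
D = jnp.ones(d_inner)

y = selective_scan(x, delta, A, B, C, D)
print(y.shape)  # (32, 8)
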
jaxonlayers/functions/utils.py
DELETED
jaxonlayers/layers/__init__.py
DELETED
@@ -1,16 +0,0 @@
-from .attention import MultiheadAttention, SqueezeExcitation
-from .convolution import ConvNormActivation
-from .normalization import BatchNorm, LayerNorm, LocalResponseNormalization
-from .regularization import StochasticDepth
-from .state_space import SelectiveStateSpace
-
-__all__ = [
-    "BatchNorm",
-    "LocalResponseNormalization",
-    "MultiheadAttention",
-    "SelectiveStateSpace",
-    "SqueezeExcitation",
-    "StochasticDepth",
-    "ConvNormActivation",
-    "LayerNorm",
-]

jaxonlayers/layers/abstract.py
DELETED
@@ -1,16 +0,0 @@
-import abc
-
-import equinox as eqx
-from jaxtyping import Array
-
-
-class AbstractNorm(eqx.Module):
-    @abc.abstractmethod
-    def __call__(self, x: Array, *_, **__) -> Array: ...
-
-
-class AbstractNormStateful(eqx.nn.StatefulLayer):
-    @abc.abstractmethod
-    def __call__(
-        self, x: Array, state: eqx.nn.State, *_, **__
-    ) -> tuple[Array, eqx.nn.State]: ...

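The two deleted base classes only pin down call signatures for stateless and stateful normalization layers. A purely hypothetical subclass sketch against the 0.1.0 `AbstractNorm` interface, to show what the contract looked like (the `RMSNorm` class is not part of the package):

import jax.numpy as jnp
from jaxtyping import Array

from jaxonlayers.layers.abstract import AbstractNorm  # 0.1.0 only


class RMSNorm(AbstractNorm):
    """Hypothetical stateless norm implementing the AbstractNorm interface."""

    weight: Array
    eps: float = 1e-6

    def __call__(self, x: Array, *_, **__) -> Array:
        # Root-mean-square normalisation over the last axis, then a learned scale.
        rms = jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + self.eps)
        return self.weight * (x / rms)


norm = RMSNorm(weight=jnp.ones(16))
print(norm(jnp.arange(16.0)).shape)  # (16,)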