jaxonlayers 0.2.0__tar.gz → 0.2.2__tar.gz
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/PKG-INFO +1 -1
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/sequential.py +1 -1
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/transformer.py +17 -17
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/pyproject.toml +1 -1
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_transformer.py +30 -30
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/.gitignore +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/.pre-commit-config.yaml +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/.python-version +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/README.md +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/__init__.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/__init__.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/attention.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/embedding.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/initialization.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/masking.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/normalization.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/regularization.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/state_space.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/utils.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/__init__.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/abstract.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/attention.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/convolution.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/normalization.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/regularization.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/state_space.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/__init__.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_attention.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_batch_norm.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_efficientnet_layers.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_layernorm.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_local_response_normalisation.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_mha.py +0 -0
- {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/uv.lock +0 -0
{jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/transformer.py

@@ -40,7 +40,7 @@ class TransformerEncoderLayer(eqx.Module):
     def __init__(
         self,
         d_model: int,
-
+        n_heads: int,
         dim_feedforward: int = 2048,
         dropout_p: float = 0.1,
         activation: Callable = jax.nn.relu,
@@ -50,7 +50,7 @@ class TransformerEncoderLayer(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -58,7 +58,7 @@ class TransformerEncoderLayer(eqx.Module):
         self.inference = inference
         mha_key, lin1_key, lin2_key = jax.random.split(key, 3)
         self.self_attn = eqx.nn.MultiheadAttention(
-
+            n_heads,
             d_model,
             dropout_p=dropout_p,
             use_query_bias=use_bias,
@@ -208,7 +208,7 @@ class TransformerDecoderLayer(eqx.Module):
     def __init__(
         self,
         d_model: int,
-
+        n_heads: int,
         dim_feedforward: int = 2048,
         dropout_p: float = 0.1,
         activation: Callable = jax.nn.relu,
@@ -218,7 +218,7 @@ class TransformerDecoderLayer(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -227,7 +227,7 @@ class TransformerDecoderLayer(eqx.Module):

         mha_key1, mha_key2, lin1_key, lin2_key = jax.random.split(key, 4)
         self.self_attn = eqx.nn.MultiheadAttention(
-
+            n_heads,
             d_model,
             dropout_p=dropout_p,
             use_query_bias=use_bias,
@@ -239,7 +239,7 @@ class TransformerDecoderLayer(eqx.Module):
             dtype=dtype,
         )
         self.multihead_attn = eqx.nn.MultiheadAttention(
-
+            n_heads,
             d_model,
             dropout_p=dropout_p,
             use_query_bias=use_bias,
@@ -455,7 +455,7 @@ class TransformerEncoder(eqx.Module):
     def __init__(
         self,
         d_model: int,
-
+        n_heads: int,
         num_layers: int = 6,
         dim_feedforward: int = 2048,
         dropout_p: float = 0.1,
@@ -467,7 +467,7 @@ class TransformerEncoder(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -478,7 +478,7 @@ class TransformerEncoder(eqx.Module):
         self.layers = [
             TransformerEncoderLayer(
                 d_model=d_model,
-
+                n_heads=n_heads,
                 dim_feedforward=dim_feedforward,
                 dropout_p=dropout_p,
                 activation=activation,
@@ -534,7 +534,7 @@ class TransformerDecoder(eqx.Module):
     def __init__(
         self,
         d_model: int,
-
+        n_heads: int,
         num_layers: int = 6,
         dim_feedforward: int = 2048,
         dropout_p: float = 0.1,
@@ -546,7 +546,7 @@ class TransformerDecoder(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -557,7 +557,7 @@ class TransformerDecoder(eqx.Module):
         self.layers = [
             TransformerDecoderLayer(
                 d_model=d_model,
-
+                n_heads=n_heads,
                 dim_feedforward=dim_feedforward,
                 dropout_p=dropout_p,
                 activation=activation,
@@ -627,7 +627,7 @@ class Transformer(eqx.Module):
     def __init__(
         self,
         d_model: int,
-
+        n_heads: int,
         num_encoder_layers: int = 6,
         num_decoder_layers: int = 6,
         dim_feedforward: int = 2048,
@@ -639,7 +639,7 @@ class Transformer(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -650,7 +650,7 @@ class Transformer(eqx.Module):

         self.encoder = TransformerEncoder(
             d_model=d_model,
-
+            n_heads=n_heads,
             num_layers=num_encoder_layers,
             dim_feedforward=dim_feedforward,
             dropout_p=dropout_p,
@@ -666,7 +666,7 @@ class Transformer(eqx.Module):

         self.decoder = TransformerDecoder(
             d_model=d_model,
-
+            n_heads=n_heads,
             num_layers=num_decoder_layers,
             dim_feedforward=dim_feedforward,
             dropout_p=dropout_p,
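Taken together, the transformer.py hunks make two API changes: every transformer module now takes an explicit n_heads argument, and dtype gains a None default that falls back to default_floating_dtype() inside __init__. A minimal usage sketch of the 0.2.2 constructor signature, assuming TransformerEncoderLayer is importable from jaxonlayers.layers.transformer as the file path suggests, with purely illustrative sizes:

import jax
from jaxonlayers.layers.transformer import TransformerEncoderLayer

# Sketch only: the sizes below are illustrative, not taken from the package docs.
layer = TransformerEncoderLayer(
    d_model=64,
    n_heads=4,               # head-count parameter shown in the 0.2.2 diff
    key=jax.random.key(0),
    # dtype omitted: with the new "dtype: Any = None" default it falls back
    # to default_floating_dtype() inside __init__.
)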
{jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_transformer.py

@@ -25,12 +25,12 @@ class TestTransformerEncoderLayer:
     )
     def test_masking(self, is_causal, use_explicit_mask):
         d_model = 64
-
+        n_heads = 4
         seq_len = 10

         layer = TransformerEncoderLayer(
             d_model=d_model,
-
+            n_heads=n_heads,
             key=jax.random.key(0),
             dtype=jnp.float32,
         )
@@ -43,12 +43,12 @@ class TestTransformerEncoderLayer:

     def test_jit_no_retrace(self):
         d_model = 64
-
+        n_heads = 4
         seq_len = 10

         layer = TransformerEncoderLayer(
             d_model=d_model,
-
+            n_heads=n_heads,
             key=jax.random.key(0),
             dtype=jnp.float32,
         )
@@ -84,13 +84,13 @@ class TestTransformerDecoderLayer:
         self, tgt_is_causal, memory_is_causal, use_tgt_mask, use_memory_mask
     ):
         d_model = 64
-
+        n_heads = 4
         tgt_len = 10
         src_len = 12

         layer = TransformerDecoderLayer(
             d_model=d_model,
-
+            n_heads=n_heads,
             key=jax.random.key(0),
             dtype=jnp.float32,
         )
@@ -116,13 +116,13 @@ class TestTransformerDecoderLayer:

     def test_with_process_heads(self):
         d_model = 64
-
+        n_heads = 4
         tgt_len = 10
         src_len = 12

         layer = TransformerDecoderLayer(
             d_model=d_model,
-
+            n_heads=n_heads,
             key=jax.random.key(0),
             dtype=jnp.float32,
         )
@@ -144,13 +144,13 @@ class TestTransformerDecoderLayer:

     def test_jit_no_retrace(self):
         d_model = 64
-
+        n_heads = 4
         tgt_len = 10
         src_len = 12

         layer = TransformerDecoderLayer(
             d_model=d_model,
-
+            n_heads=n_heads,
             key=jax.random.key(0),
             dtype=jnp.float32,
         )
@@ -184,13 +184,13 @@ class TestTransformerEncoder:
     )
     def test_masking(self, is_causal, use_explicit_mask):
         d_model = 64
-
+        n_heads = 4
         num_layers = 3
         seq_len = 10

         encoder = TransformerEncoder(
             d_model=d_model,
-
+            n_heads=n_heads,
             num_layers=num_layers,
             key=jax.random.key(0),
             dtype=jnp.float32,
@@ -204,13 +204,13 @@ class TestTransformerEncoder:

     def test_with_process_heads(self):
         d_model = 64
-
+        n_heads = 4
         num_layers = 3
         seq_len = 10

         encoder = TransformerEncoder(
             d_model=d_model,
-
+            n_heads=n_heads,
             num_layers=num_layers,
             key=jax.random.key(0),
             dtype=jnp.float32,
@@ -226,13 +226,13 @@ class TestTransformerEncoder:

     def test_jit_no_retrace(self):
         d_model = 64
-
+        n_heads = 4
         num_layers = 3
         seq_len = 10

         encoder = TransformerEncoder(
             d_model=d_model,
-
+            n_heads=n_heads,
             num_layers=num_layers,
             key=jax.random.key(0),
             dtype=jnp.float32,
@@ -269,14 +269,14 @@ class TestTransformerDecoder:
         self, tgt_is_causal, memory_is_causal, use_tgt_mask, use_memory_mask
     ):
         d_model = 64
-
+        n_heads = 4
         num_layers = 3
         tgt_len = 10
         src_len = 12

         decoder = TransformerDecoder(
             d_model=d_model,
-
+            n_heads=n_heads,
             num_layers=num_layers,
             key=jax.random.key(0),
             dtype=jnp.float32,
@@ -303,14 +303,14 @@ class TestTransformerDecoder:

     def test_with_process_heads(self):
         d_model = 64
-
+        n_heads = 4
         num_layers = 3
         tgt_len = 10
         src_len = 12

         decoder = TransformerDecoder(
             d_model=d_model,
-
+            n_heads=n_heads,
             num_layers=num_layers,
             key=jax.random.key(0),
             dtype=jnp.float32,
@@ -333,14 +333,14 @@ class TestTransformerDecoder:

     def test_jit_no_retrace(self):
         d_model = 64
-
+        n_heads = 4
         num_layers = 3
         tgt_len = 10
         src_len = 12

         decoder = TransformerDecoder(
             d_model=d_model,
-
+            n_heads=n_heads,
             num_layers=num_layers,
             key=jax.random.key(0),
             dtype=jnp.float32,
@@ -383,13 +383,13 @@ class TestTransformer:
         use_memory_mask,
     ):
         d_model = 64
-
+        n_heads = 4
         src_len = 12
         tgt_len = 10

         transformer = Transformer(
             d_model=d_model,
-
+            n_heads=n_heads,
             key=jax.random.key(0),
             dtype=jnp.float32,
         )
@@ -418,13 +418,13 @@ class TestTransformer:

     def test_with_process_heads(self):
         d_model = 64
-
+        n_heads = 4
         src_len = 12
         tgt_len = 10

         transformer = Transformer(
             d_model=d_model,
-
+            n_heads=n_heads,
             key=jax.random.key(0),
             dtype=jnp.float32,
         )
@@ -455,13 +455,13 @@ class TestTransformer:
     )
     def test_activations(self, activation):
         d_model = 64
-
+        n_heads = 4
         src_len = 12
         tgt_len = 10

         transformer = Transformer(
             d_model=d_model,
-
+            n_heads=n_heads,
             activation=activation,
             key=jax.random.key(0),
             dtype=jnp.float32,
@@ -475,13 +475,13 @@ class TestTransformer:

     def test_jit_no_retrace(self):
         d_model = 64
-
+        n_heads = 4
         src_len = 12
         tgt_len = 10

         transformer = Transformer(
             d_model=d_model,
-
+            n_heads=n_heads,
             key=jax.random.key(0),
             dtype=jnp.float32,
         )
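The test updates are purely mechanical: each test now defines n_heads = 4 and forwards it to the module under construction. A condensed sketch of that setup, assuming the same imports the test module already uses (they are not part of this diff):

import jax
import jax.numpy as jnp
from jaxonlayers.layers.transformer import Transformer, TransformerEncoder

# Mirrors the updated test setup: every module receives n_heads explicitly.
encoder = TransformerEncoder(
    d_model=64,
    n_heads=4,
    num_layers=3,
    key=jax.random.key(0),
    dtype=jnp.float32,
)
transformer = Transformer(
    d_model=64,
    n_heads=4,
    key=jax.random.key(0),
    dtype=jnp.float32,
)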