jaxonlayers 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxonlayers/layers/sequential.py +1 -1
- jaxonlayers/layers/transformer.py +17 -17
- {jaxonlayers-0.2.0.dist-info → jaxonlayers-0.2.2.dist-info}/METADATA +1 -1
- {jaxonlayers-0.2.0.dist-info → jaxonlayers-0.2.2.dist-info}/RECORD +5 -5
- {jaxonlayers-0.2.0.dist-info → jaxonlayers-0.2.2.dist-info}/WHEEL +0 -0
jaxonlayers/layers/sequential.py
CHANGED
jaxonlayers/layers/transformer.py
CHANGED

@@ -40,7 +40,7 @@ class TransformerEncoderLayer(eqx.Module):
     def __init__(
         self,
         d_model: int,
-
+        n_heads: int,
         dim_feedforward: int = 2048,
         dropout_p: float = 0.1,
         activation: Callable = jax.nn.relu,
@@ -50,7 +50,7 @@ class TransformerEncoderLayer(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -58,7 +58,7 @@ class TransformerEncoderLayer(eqx.Module):
         self.inference = inference
         mha_key, lin1_key, lin2_key = jax.random.split(key, 3)
         self.self_attn = eqx.nn.MultiheadAttention(
-
+            n_heads,
             d_model,
             dropout_p=dropout_p,
             use_query_bias=use_bias,
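With 0.2.2 the encoder layer takes n_heads positionally after d_model, and dtype becomes an optional keyword argument that falls back to default_floating_dtype(). A minimal construction sketch, assuming the import path follows the file location above (jaxonlayers.layers.transformer) and using only the arguments visible in these hunks:

import jax
from jaxonlayers.layers.transformer import TransformerEncoderLayer  # assumed path

key = jax.random.PRNGKey(0)

# dtype is keyword-only and now optional: omitting it triggers the
# `if dtype is None: dtype = default_floating_dtype()` branch shown above.
layer = TransformerEncoderLayer(
    d_model=512,
    n_heads=8,              # added/changed in this release
    dim_feedforward=2048,
    dropout_p=0.1,
    key=key,
)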
@@ -208,7 +208,7 @@ class TransformerDecoderLayer(eqx.Module):
     def __init__(
         self,
         d_model: int,
-
+        n_heads: int,
         dim_feedforward: int = 2048,
         dropout_p: float = 0.1,
         activation: Callable = jax.nn.relu,
@@ -218,7 +218,7 @@ class TransformerDecoderLayer(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -227,7 +227,7 @@ class TransformerDecoderLayer(eqx.Module):

         mha_key1, mha_key2, lin1_key, lin2_key = jax.random.split(key, 4)
         self.self_attn = eqx.nn.MultiheadAttention(
-
+            n_heads,
             d_model,
             dropout_p=dropout_p,
             use_query_bias=use_bias,
@@ -239,7 +239,7 @@ class TransformerDecoderLayer(eqx.Module):
             dtype=dtype,
         )
         self.multihead_attn = eqx.nn.MultiheadAttention(
-
+            n_heads,
             d_model,
             dropout_p=dropout_p,
             use_query_bias=use_bias,
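For context, the decoder layer's two attention blocks are plain eqx.nn.MultiheadAttention modules, both sized by the same n_heads and d_model. A standalone sketch of the equivalent construction (key handling simplified; the bias and dtype options from the hunks are omitted):

import equinox as eqx
import jax

d_model, n_heads, dropout_p = 512, 8, 0.1
self_key, cross_key = jax.random.split(jax.random.PRNGKey(0))

# Self-attention over the target sequence.
self_attn = eqx.nn.MultiheadAttention(n_heads, d_model, dropout_p=dropout_p, key=self_key)

# Attention from target queries to the encoder output.
multihead_attn = eqx.nn.MultiheadAttention(n_heads, d_model, dropout_p=dropout_p, key=cross_key)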
@@ -455,7 +455,7 @@ class TransformerEncoder(eqx.Module):
     def __init__(
         self,
         d_model: int,
-
+        n_heads: int,
         num_layers: int = 6,
         dim_feedforward: int = 2048,
         dropout_p: float = 0.1,
@@ -467,7 +467,7 @@ class TransformerEncoder(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -478,7 +478,7 @@ class TransformerEncoder(eqx.Module):
         self.layers = [
             TransformerEncoderLayer(
                 d_model=d_model,
-
+                n_heads=n_heads,
                 dim_feedforward=dim_feedforward,
                 dropout_p=dropout_p,
                 activation=activation,
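The encoder stack simply forwards n_heads into each TransformerEncoderLayer it builds. A construction sketch under the same assumed import path, with num_layers and the other defaults taken from the hunks above:

import jax
from jaxonlayers.layers.transformer import TransformerEncoder  # assumed path

encoder = TransformerEncoder(
    d_model=512,
    n_heads=8,
    num_layers=6,            # default shown in the signature above
    dim_feedforward=2048,
    dropout_p=0.1,
    key=jax.random.PRNGKey(0),
)                            # dtype omitted -> default_floating_dtype()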
@@ -534,7 +534,7 @@ class TransformerDecoder(eqx.Module):
     def __init__(
         self,
         d_model: int,
-
+        n_heads: int,
         num_layers: int = 6,
         dim_feedforward: int = 2048,
         dropout_p: float = 0.1,
@@ -546,7 +546,7 @@ class TransformerDecoder(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -557,7 +557,7 @@ class TransformerDecoder(eqx.Module):
         self.layers = [
             TransformerDecoderLayer(
                 d_model=d_model,
-
+                n_heads=n_heads,
                 dim_feedforward=dim_feedforward,
                 dropout_p=dropout_p,
                 activation=activation,
@@ -627,7 +627,7 @@ class Transformer(eqx.Module):
     def __init__(
         self,
         d_model: int,
-
+        n_heads: int,
         num_encoder_layers: int = 6,
         num_decoder_layers: int = 6,
         dim_feedforward: int = 2048,
@@ -639,7 +639,7 @@ class Transformer(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -650,7 +650,7 @@ class Transformer(eqx.Module):

         self.encoder = TransformerEncoder(
             d_model=d_model,
-
+            n_heads=n_heads,
             num_layers=num_encoder_layers,
             dim_feedforward=dim_feedforward,
             dropout_p=dropout_p,
@@ -666,7 +666,7 @@ class Transformer(eqx.Module):

         self.decoder = TransformerDecoder(
             d_model=d_model,
-
+            n_heads=n_heads,
             num_layers=num_decoder_layers,
             dim_feedforward=dim_feedforward,
             dropout_p=dropout_p,
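Putting it together, the top-level Transformer passes n_heads through to both stacks, and dtype may now be omitted for every class touched in this release. A sketch, again assuming the jaxonlayers.layers.transformer import path and using only arguments visible in these hunks:

import jax
from jaxonlayers.layers.transformer import Transformer  # assumed path

model = Transformer(
    d_model=512,
    n_heads=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    key=jax.random.PRNGKey(42),
    # dtype left unset: resolved to default_floating_dtype() in __init__
)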
{jaxonlayers-0.2.0.dist-info → jaxonlayers-0.2.2.dist-info}/RECORD
CHANGED

@@ -14,9 +14,9 @@ jaxonlayers/layers/attention.py,sha256=RgtpzBPxJ4tDcUjiq_Wh_7GJndmBY6UKtbEuHGLA1
 jaxonlayers/layers/convolution.py,sha256=k0dMFBDjzycB7UNuyHqKihJtBa6u93V6OLxyUUyipN4,3247
 jaxonlayers/layers/normalization.py,sha256=3aGzNzDN05A72ZHLUM2w9WpicLtGsjzj1l0jhuyn63U,8379
 jaxonlayers/layers/regularization.py,sha256=ZrvtBJPH84xuxrxEbZc7TBxjp8OvKEv4ecan5s8F9zs,563
-jaxonlayers/layers/sequential.py,sha256=
+jaxonlayers/layers/sequential.py,sha256=xBZavhSra0oZUZjSlThzIGER5xR62n46mgTkPqRV2Y0,2843
 jaxonlayers/layers/state_space.py,sha256=oDVRbduNtU48Q4rLd-XywZcqVN0QYTlq1UUhOXcGLoo,2537
-jaxonlayers/layers/transformer.py,sha256=
-jaxonlayers-0.2.
-jaxonlayers-0.2.
-jaxonlayers-0.2.
+jaxonlayers/layers/transformer.py,sha256=Syuh_kN-I-Gg-C20trQcNio1WEG88OHDRl3ShFaQNqQ,21848
+jaxonlayers-0.2.2.dist-info/METADATA,sha256=sYqihkNlwRrIMMPtmi8zIMPk4TOef3H6ZHB44gqKINM,565
+jaxonlayers-0.2.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+jaxonlayers-0.2.2.dist-info/RECORD,,
{jaxonlayers-0.2.0.dist-info → jaxonlayers-0.2.2.dist-info}/WHEEL
File without changes
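Each RECORD row above follows the wheel spec: the file path, an unpadded URL-safe base64 SHA-256 digest, and the file size in bytes, which is why the sequential.py and transformer.py rows change alongside the renamed dist-info entries. A small sketch of how such a row is computed (illustrative helper, not part of the package):

import base64
import hashlib
from pathlib import Path

def record_row(path: str) -> str:
    # Hash the file contents and encode the digest the way RECORD expects.
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"

# e.g. record_row("jaxonlayers/layers/transformer.py"), run against the
# extracted 0.2.2 wheel, should reproduce the entry ending in ",21848".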