jaxonlayers 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
jaxonlayers/layers/transformer.py

@@ -40,7 +40,7 @@ class TransformerEncoderLayer(eqx.Module):
     def __init__(
         self,
         d_model: int,
-        nhead: int,
+        n_heads: int,
         dim_feedforward: int = 2048,
         dropout_p: float = 0.1,
         activation: Callable = jax.nn.relu,
@@ -50,7 +50,7 @@ class TransformerEncoderLayer(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -58,7 +58,7 @@ class TransformerEncoderLayer(eqx.Module):
         self.inference = inference
         mha_key, lin1_key, lin2_key = jax.random.split(key, 3)
         self.self_attn = eqx.nn.MultiheadAttention(
-            nhead,
+            n_heads,
             d_model,
             dropout_p=dropout_p,
             use_query_bias=use_bias,
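
Under 0.2.1 the encoder layer is therefore constructed with the `n_heads` keyword, and `dtype` can be omitted entirely. A minimal sketch of the new call, importing from the `jaxonlayers.layers.transformer` module listed in RECORD (the exact public import path is an assumption):

    import jax
    # Import path taken from RECORD; a top-level re-export may also exist.
    from jaxonlayers.layers.transformer import TransformerEncoderLayer

    key = jax.random.PRNGKey(0)

    # `n_heads` replaces the old `nhead` keyword; `dtype` may now be
    # omitted, in which case default_floating_dtype() is used internally.
    layer = TransformerEncoderLayer(
        d_model=512,
        n_heads=8,
        dim_feedforward=2048,
        dropout_p=0.1,
        key=key,
    )
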
@@ -208,7 +208,7 @@ class TransformerDecoderLayer(eqx.Module):
     def __init__(
         self,
         d_model: int,
-        nhead: int,
+        n_heads: int,
         dim_feedforward: int = 2048,
         dropout_p: float = 0.1,
         activation: Callable = jax.nn.relu,
@@ -218,7 +218,7 @@ class TransformerDecoderLayer(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -227,7 +227,7 @@ class TransformerDecoderLayer(eqx.Module):
 
         mha_key1, mha_key2, lin1_key, lin2_key = jax.random.split(key, 4)
         self.self_attn = eqx.nn.MultiheadAttention(
-            nhead,
+            n_heads,
             d_model,
             dropout_p=dropout_p,
             use_query_bias=use_bias,
@@ -239,7 +239,7 @@ class TransformerDecoderLayer(eqx.Module):
             dtype=dtype,
         )
         self.multihead_attn = eqx.nn.MultiheadAttention(
-            nhead,
+            n_heads,
             d_model,
             dropout_p=dropout_p,
             use_query_bias=use_bias,
@@ -455,7 +455,7 @@ class TransformerEncoder(eqx.Module):
     def __init__(
         self,
         d_model: int,
-        nhead: int,
+        n_heads: int,
         num_layers: int = 6,
         dim_feedforward: int = 2048,
         dropout_p: float = 0.1,
@@ -467,7 +467,7 @@ class TransformerEncoder(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -478,7 +478,7 @@ class TransformerEncoder(eqx.Module):
         self.layers = [
             TransformerEncoderLayer(
                 d_model=d_model,
-                nhead=nhead,
+                n_heads=n_heads,
                 dim_feedforward=dim_feedforward,
                 dropout_p=dropout_p,
                 activation=activation,
@@ -534,7 +534,7 @@ class TransformerDecoder(eqx.Module):
     def __init__(
         self,
         d_model: int,
-        nhead: int,
+        n_heads: int,
         num_layers: int = 6,
         dim_feedforward: int = 2048,
         dropout_p: float = 0.1,
@@ -546,7 +546,7 @@ class TransformerDecoder(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -557,7 +557,7 @@ class TransformerDecoder(eqx.Module):
         self.layers = [
             TransformerDecoderLayer(
                 d_model=d_model,
-                nhead=nhead,
+                n_heads=n_heads,
                 dim_feedforward=dim_feedforward,
                 dropout_p=dropout_p,
                 activation=activation,
@@ -627,7 +627,7 @@ class Transformer(eqx.Module):
     def __init__(
         self,
         d_model: int,
-        nhead: int,
+        n_heads: int,
         num_encoder_layers: int = 6,
         num_decoder_layers: int = 6,
         dim_feedforward: int = 2048,
@@ -639,7 +639,7 @@ class Transformer(eqx.Module):
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        dtype: Any,
+        dtype: Any = None,
     ):
         if dtype is None:
             dtype = default_floating_dtype()
@@ -650,7 +650,7 @@ class Transformer(eqx.Module):
 
         self.encoder = TransformerEncoder(
             d_model=d_model,
-            nhead=nhead,
+            n_heads=n_heads,
             num_layers=num_encoder_layers,
             dim_feedforward=dim_feedforward,
             dropout_p=dropout_p,
@@ -666,7 +666,7 @@ class Transformer(eqx.Module):
 
         self.decoder = TransformerDecoder(
             d_model=d_model,
-            nhead=nhead,
+            n_heads=n_heads,
             num_layers=num_decoder_layers,
             dim_feedforward=dim_feedforward,
             dropout_p=dropout_p,
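
Taken together, the rename from `nhead` to `n_heads` is a breaking change for any caller that passed the head count by keyword, across all five classes touched above. A minimal migration sketch, again assuming the import path from RECORD:

    import jax
    from jaxonlayers.layers.transformer import Transformer  # assumed import path

    key = jax.random.PRNGKey(0)

    # 0.2.0 spelling -- under 0.2.1 this raises
    # TypeError: got an unexpected keyword argument 'nhead':
    # model = Transformer(d_model=512, nhead=8, key=key, dtype=None)

    # 0.2.1 spelling -- `n_heads` replaces `nhead`, and `dtype` now defaults
    # to None (resolved via default_floating_dtype()):
    model = Transformer(
        d_model=512,
        n_heads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        key=key,
    )
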
jaxonlayers-0.2.0.dist-info/METADATA → jaxonlayers-0.2.1.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: jaxonlayers
-Version: 0.2.0
+Version: 0.2.1
 Summary: Additional layers and functions that extend Equinox
 Requires-Python: >=3.13
 Requires-Dist: beartype>=0.21.0
jaxonlayers-0.2.0.dist-info/RECORD → jaxonlayers-0.2.1.dist-info/RECORD

@@ -16,7 +16,7 @@ jaxonlayers/layers/normalization.py,sha256=3aGzNzDN05A72ZHLUM2w9WpicLtGsjzj1l0jh
 jaxonlayers/layers/regularization.py,sha256=ZrvtBJPH84xuxrxEbZc7TBxjp8OvKEv4ecan5s8F9zs,563
 jaxonlayers/layers/sequential.py,sha256=Tw98hNZiXMC-CYZD6h_pi7eAxkgHeQAUvZF2I9H0d8Y,2833
 jaxonlayers/layers/state_space.py,sha256=oDVRbduNtU48Q4rLd-XywZcqVN0QYTlq1UUhOXcGLoo,2537
-jaxonlayers/layers/transformer.py,sha256=u2nRLFp1cnM2WvAXQQ5zKyHbLFV0bUnaW91x56dl6n8,21781
-jaxonlayers-0.2.0.dist-info/METADATA,sha256=gTtPTaJzmBNNMZcTbe-gBVFT09GPjCGVUbkRVv_7hhc,565
-jaxonlayers-0.2.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-jaxonlayers-0.2.0.dist-info/RECORD,,
+jaxonlayers/layers/transformer.py,sha256=Syuh_kN-I-Gg-C20trQcNio1WEG88OHDRl3ShFaQNqQ,21848
+jaxonlayers-0.2.1.dist-info/METADATA,sha256=rHoA2ZdpFRfOv-aHd4qkzY5a-nEDTiTHf5WplX3Mz_o,565
+jaxonlayers-0.2.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+jaxonlayers-0.2.1.dist-info/RECORD,,