jaxonlayers 0.2.0__tar.gz → 0.2.2__tar.gz

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (34)
  1. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/PKG-INFO +1 -1
  2. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/sequential.py +1 -1
  3. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/transformer.py +17 -17
  4. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/pyproject.toml +1 -1
  5. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_transformer.py +30 -30
  6. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/.gitignore +0 -0
  7. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/.pre-commit-config.yaml +0 -0
  8. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/.python-version +0 -0
  9. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/README.md +0 -0
  10. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/__init__.py +0 -0
  11. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/__init__.py +0 -0
  12. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/attention.py +0 -0
  13. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/embedding.py +0 -0
  14. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/initialization.py +0 -0
  15. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/masking.py +0 -0
  16. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/normalization.py +0 -0
  17. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/regularization.py +0 -0
  18. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/state_space.py +0 -0
  19. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/functions/utils.py +0 -0
  20. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/__init__.py +0 -0
  21. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/abstract.py +0 -0
  22. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/attention.py +0 -0
  23. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/convolution.py +0 -0
  24. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/normalization.py +0 -0
  25. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/regularization.py +0 -0
  26. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/state_space.py +0 -0
  27. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/__init__.py +0 -0
  28. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_attention.py +0 -0
  29. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_batch_norm.py +0 -0
  30. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_efficientnet_layers.py +0 -0
  31. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_layernorm.py +0 -0
  32. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_local_response_normalisation.py +0 -0
  33. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_mha.py +0 -0
  34. {jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/uv.lock +0 -0
{jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: jaxonlayers
- Version: 0.2.0
+ Version: 0.2.2
  Summary: Additional layers and functions that extend Equinox
  Requires-Python: >=3.13
  Requires-Dist: beartype>=0.21.0
{jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/sequential.py
@@ -56,7 +56,7 @@ class BatchedLinear(eqx.Module):
  self.use_bias = use_bias

  def __call__(
- self, x: Float[Array, "*batch in_features"]
+ self, x: Float[Array, "*batch in_features"], key=None
  ) -> Float[Array, "*batch out_features"]:
  input_shape = x.shape

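The extra `key=None` argument matches Equinox's standard layer call signature: containers such as `eqx.nn.Sequential` pass `key=...` to every sublayer they hold, so a layer whose `__call__` only accepts `x` cannot be composed that way. A minimal sketch of the composed usage this enables, assuming an import path of `jaxonlayers.layers` and a constructor along the lines of `BatchedLinear(in_features, out_features, key=...)` (neither is shown in this diff):

import equinox as eqx
import jax
import jax.numpy as jnp

from jaxonlayers.layers import BatchedLinear  # assumed import path

k1, k2 = jax.random.split(jax.random.key(0))

# Hypothetical constructor arguments; only __call__ appears in this diff.
model = eqx.nn.Sequential(
    [
        BatchedLinear(8, 16, key=k1),
        eqx.nn.Lambda(jax.nn.relu),
        eqx.nn.Linear(16, 4, key=k2),
    ]
)

x = jnp.ones((8,))
# Sequential forwards key=... to each sublayer, so __call__ must accept it.
y = model(x, key=jax.random.key(1))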
{jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/jaxonlayers/layers/transformer.py
@@ -40,7 +40,7 @@ class TransformerEncoderLayer(eqx.Module):
  def __init__(
  self,
  d_model: int,
- nhead: int,
+ n_heads: int,
  dim_feedforward: int = 2048,
  dropout_p: float = 0.1,
  activation: Callable = jax.nn.relu,
@@ -50,7 +50,7 @@ class TransformerEncoderLayer(eqx.Module):
  inference: bool = False,
  *,
  key: PRNGKeyArray,
- dtype: Any,
+ dtype: Any = None,
  ):
  if dtype is None:
  dtype = default_floating_dtype()
@@ -58,7 +58,7 @@ class TransformerEncoderLayer(eqx.Module):
  self.inference = inference
  mha_key, lin1_key, lin2_key = jax.random.split(key, 3)
  self.self_attn = eqx.nn.MultiheadAttention(
- nhead,
+ n_heads,
  d_model,
  dropout_p=dropout_p,
  use_query_bias=use_bias,
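Together, these three hunks rename the constructor argument `nhead` to `n_heads` and make `dtype` optional, falling back to `default_floating_dtype()` when it is omitted. A hedged sketch of a 0.2.2 call site (the import path is assumed; the argument values mirror the tests further down):

import jax

from jaxonlayers.layers import TransformerEncoderLayer  # assumed import path

layer = TransformerEncoderLayer(
    d_model=64,
    n_heads=4,              # renamed from `nhead` in 0.2.0
    dim_feedforward=2048,   # default, shown for clarity
    dropout_p=0.1,          # default
    key=jax.random.key(0),
    # dtype omitted: it now defaults to None and resolves to default_floating_dtype()
)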
@@ -208,7 +208,7 @@ class TransformerDecoderLayer(eqx.Module):
  def __init__(
  self,
  d_model: int,
- nhead: int,
+ n_heads: int,
  dim_feedforward: int = 2048,
  dropout_p: float = 0.1,
  activation: Callable = jax.nn.relu,
@@ -218,7 +218,7 @@ class TransformerDecoderLayer(eqx.Module):
  inference: bool = False,
  *,
  key: PRNGKeyArray,
- dtype: Any,
+ dtype: Any = None,
  ):
  if dtype is None:
  dtype = default_floating_dtype()
@@ -227,7 +227,7 @@ class TransformerDecoderLayer(eqx.Module):

  mha_key1, mha_key2, lin1_key, lin2_key = jax.random.split(key, 4)
  self.self_attn = eqx.nn.MultiheadAttention(
- nhead,
+ n_heads,
  d_model,
  dropout_p=dropout_p,
  use_query_bias=use_bias,
@@ -239,7 +239,7 @@ class TransformerDecoderLayer(eqx.Module):
  dtype=dtype,
  )
  self.multihead_attn = eqx.nn.MultiheadAttention(
- nhead,
+ n_heads,
  d_model,
  dropout_p=dropout_p,
  use_query_bias=use_bias,
@@ -455,7 +455,7 @@ class TransformerEncoder(eqx.Module):
  def __init__(
  self,
  d_model: int,
- nhead: int,
+ n_heads: int,
  num_layers: int = 6,
  dim_feedforward: int = 2048,
  dropout_p: float = 0.1,
@@ -467,7 +467,7 @@ class TransformerEncoder(eqx.Module):
  inference: bool = False,
  *,
  key: PRNGKeyArray,
- dtype: Any,
+ dtype: Any = None,
  ):
  if dtype is None:
  dtype = default_floating_dtype()
@@ -478,7 +478,7 @@ class TransformerEncoder(eqx.Module):
  self.layers = [
  TransformerEncoderLayer(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  dim_feedforward=dim_feedforward,
  dropout_p=dropout_p,
  activation=activation,
@@ -534,7 +534,7 @@ class TransformerDecoder(eqx.Module):
  def __init__(
  self,
  d_model: int,
- nhead: int,
+ n_heads: int,
  num_layers: int = 6,
  dim_feedforward: int = 2048,
  dropout_p: float = 0.1,
@@ -546,7 +546,7 @@ class TransformerDecoder(eqx.Module):
  inference: bool = False,
  *,
  key: PRNGKeyArray,
- dtype: Any,
+ dtype: Any = None,
  ):
  if dtype is None:
  dtype = default_floating_dtype()
@@ -557,7 +557,7 @@ class TransformerDecoder(eqx.Module):
  self.layers = [
  TransformerDecoderLayer(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  dim_feedforward=dim_feedforward,
  dropout_p=dropout_p,
  activation=activation,
@@ -627,7 +627,7 @@ class Transformer(eqx.Module):
  def __init__(
  self,
  d_model: int,
- nhead: int,
+ n_heads: int,
  num_encoder_layers: int = 6,
  num_decoder_layers: int = 6,
  dim_feedforward: int = 2048,
@@ -639,7 +639,7 @@ class Transformer(eqx.Module):
  inference: bool = False,
  *,
  key: PRNGKeyArray,
- dtype: Any,
+ dtype: Any = None,
  ):
  if dtype is None:
  dtype = default_floating_dtype()
@@ -650,7 +650,7 @@ class Transformer(eqx.Module):

  self.encoder = TransformerEncoder(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  num_layers=num_encoder_layers,
  dim_feedforward=dim_feedforward,
  dropout_p=dropout_p,
@@ -666,7 +666,7 @@ class Transformer(eqx.Module):

  self.decoder = TransformerDecoder(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  num_layers=num_decoder_layers,
  dim_feedforward=dim_feedforward,
  dropout_p=dropout_p,
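The same two changes propagate through `TransformerDecoderLayer`, `TransformerEncoder`, `TransformerDecoder`, and `Transformer`, so every public constructor in this module now takes `n_heads` and an optional `dtype`. A hedged migration sketch for the top-level `Transformer` (import path assumed; values mirror the tests below):

import jax
import jax.numpy as jnp

from jaxonlayers.layers import Transformer  # assumed import path

# 0.2.0 call site:
# model = Transformer(d_model=64, nhead=4, key=jax.random.key(0), dtype=jnp.float32)

# 0.2.2 call site: `nhead` -> `n_heads`, and `dtype` may be dropped.
model = Transformer(
    d_model=64,
    n_heads=4,
    num_encoder_layers=6,  # defaults, per the hunks above
    num_decoder_layers=6,
    key=jax.random.key(0),
)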
{jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "jaxonlayers"
- version = "0.2.0"
+ version = "0.2.2"
  description = "Additional layers and functions that extend Equinox"
  readme = "README.md"
  requires-python = ">=3.13"
{jaxonlayers-0.2.0 → jaxonlayers-0.2.2}/tests/test_transformer.py
@@ -25,12 +25,12 @@ class TestTransformerEncoderLayer:
  )
  def test_masking(self, is_causal, use_explicit_mask):
  d_model = 64
- nhead = 4
+ n_heads = 4
  seq_len = 10

  layer = TransformerEncoderLayer(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  key=jax.random.key(0),
  dtype=jnp.float32,
  )
@@ -43,12 +43,12 @@ class TestTransformerEncoderLayer:

  def test_jit_no_retrace(self):
  d_model = 64
- nhead = 4
+ n_heads = 4
  seq_len = 10

  layer = TransformerEncoderLayer(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  key=jax.random.key(0),
  dtype=jnp.float32,
  )
@@ -84,13 +84,13 @@ class TestTransformerDecoderLayer:
  self, tgt_is_causal, memory_is_causal, use_tgt_mask, use_memory_mask
  ):
  d_model = 64
- nhead = 4
+ n_heads = 4
  tgt_len = 10
  src_len = 12

  layer = TransformerDecoderLayer(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  key=jax.random.key(0),
  dtype=jnp.float32,
  )
@@ -116,13 +116,13 @@ class TestTransformerDecoderLayer:

  def test_with_process_heads(self):
  d_model = 64
- nhead = 4
+ n_heads = 4
  tgt_len = 10
  src_len = 12

  layer = TransformerDecoderLayer(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  key=jax.random.key(0),
  dtype=jnp.float32,
  )
@@ -144,13 +144,13 @@ class TestTransformerDecoderLayer:

  def test_jit_no_retrace(self):
  d_model = 64
- nhead = 4
+ n_heads = 4
  tgt_len = 10
  src_len = 12

  layer = TransformerDecoderLayer(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  key=jax.random.key(0),
  dtype=jnp.float32,
  )
@@ -184,13 +184,13 @@ class TestTransformerEncoder:
  )
  def test_masking(self, is_causal, use_explicit_mask):
  d_model = 64
- nhead = 4
+ n_heads = 4
  num_layers = 3
  seq_len = 10

  encoder = TransformerEncoder(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  num_layers=num_layers,
  key=jax.random.key(0),
  dtype=jnp.float32,
@@ -204,13 +204,13 @@ class TestTransformerEncoder:

  def test_with_process_heads(self):
  d_model = 64
- nhead = 4
+ n_heads = 4
  num_layers = 3
  seq_len = 10

  encoder = TransformerEncoder(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  num_layers=num_layers,
  key=jax.random.key(0),
  dtype=jnp.float32,
@@ -226,13 +226,13 @@ class TestTransformerEncoder:

  def test_jit_no_retrace(self):
  d_model = 64
- nhead = 4
+ n_heads = 4
  num_layers = 3
  seq_len = 10

  encoder = TransformerEncoder(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  num_layers=num_layers,
  key=jax.random.key(0),
  dtype=jnp.float32,
@@ -269,14 +269,14 @@ class TestTransformerDecoder:
  self, tgt_is_causal, memory_is_causal, use_tgt_mask, use_memory_mask
  ):
  d_model = 64
- nhead = 4
+ n_heads = 4
  num_layers = 3
  tgt_len = 10
  src_len = 12

  decoder = TransformerDecoder(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  num_layers=num_layers,
  key=jax.random.key(0),
  dtype=jnp.float32,
@@ -303,14 +303,14 @@ class TestTransformerDecoder:

  def test_with_process_heads(self):
  d_model = 64
- nhead = 4
+ n_heads = 4
  num_layers = 3
  tgt_len = 10
  src_len = 12

  decoder = TransformerDecoder(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  num_layers=num_layers,
  key=jax.random.key(0),
  dtype=jnp.float32,
@@ -333,14 +333,14 @@ class TestTransformerDecoder:

  def test_jit_no_retrace(self):
  d_model = 64
- nhead = 4
+ n_heads = 4
  num_layers = 3
  tgt_len = 10
  src_len = 12

  decoder = TransformerDecoder(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  num_layers=num_layers,
  key=jax.random.key(0),
  dtype=jnp.float32,
@@ -383,13 +383,13 @@ class TestTransformer:
  use_memory_mask,
  ):
  d_model = 64
- nhead = 4
+ n_heads = 4
  src_len = 12
  tgt_len = 10

  transformer = Transformer(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  key=jax.random.key(0),
  dtype=jnp.float32,
  )
@@ -418,13 +418,13 @@ class TestTransformer:

  def test_with_process_heads(self):
  d_model = 64
- nhead = 4
+ n_heads = 4
  src_len = 12
  tgt_len = 10

  transformer = Transformer(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  key=jax.random.key(0),
  dtype=jnp.float32,
  )
@@ -455,13 +455,13 @@ class TestTransformer:
  )
  def test_activations(self, activation):
  d_model = 64
- nhead = 4
+ n_heads = 4
  src_len = 12
  tgt_len = 10

  transformer = Transformer(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  activation=activation,
  key=jax.random.key(0),
  dtype=jnp.float32,
@@ -475,13 +475,13 @@ class TestTransformer:

  def test_jit_no_retrace(self):
  d_model = 64
- nhead = 4
+ n_heads = 4
  src_len = 12
  tgt_len = 10

  transformer = Transformer(
  d_model=d_model,
- nhead=nhead,
+ n_heads=n_heads,
  key=jax.random.key(0),
  dtype=jnp.float32,
  )