xax 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/nn/embeddings.py CHANGED
@@ -33,10 +33,10 @@ class LearnedPositionalEmbeddings(eqx.Module):
33
33
  learnable: Whether the embeddings are learnable.
34
34
  """
35
35
 
36
- max_tsz: int = eqx.field(static=True)
37
- embed_dim: int = eqx.field(static=True)
38
- learnable: bool = eqx.field(static=True)
39
- embeddings_tc: Array
36
+ max_tsz: int = eqx.field()
37
+ embed_dim: int = eqx.field()
38
+ learnable: bool = eqx.field()
39
+ embeddings_tc: Array = eqx.field()
40
40
 
41
41
  def __init__(
42
42
  self,
@@ -74,10 +74,10 @@ class SinusoidalEmbeddings(eqx.Module):
74
74
  base: The base for the sinusoidal embeddings.
75
75
  """
76
76
 
77
- base: int = eqx.field(static=True)
78
- max_tsz: int | None = eqx.field(static=True)
79
- embed_dim: int | None = eqx.field(static=True)
80
- embeddings_tc: Array | None
77
+ base: int = eqx.field()
78
+ max_tsz: int | None = eqx.field()
79
+ embed_dim: int | None = eqx.field()
80
+ embeddings_tc: Array | None = eqx.field()
81
81
 
82
82
  def __init__(
83
83
  self,
@@ -91,8 +91,8 @@ class SinusoidalEmbeddings(eqx.Module):
91
91
  self.max_tsz = max_tsz
92
92
  self.embed_dim = embed_dim
93
93
  self.base = base
94
+ self.embeddings_tc = None
94
95
 
95
- self.embeddings_tc: Array | None = None
96
96
  if learnable:
97
97
  assert max_tsz is not None, "Learnable parameters require `max_tsz` to be set"
98
98
  assert embed_dim is not None, "Learnable parameters require `embed_dim` to be set"
@@ -192,7 +192,7 @@ class RotaryEmbeddings(eqx.Module):
192
192
  base: The base for the sinusoidal embeddings.
193
193
  """
194
194
 
195
- base: int = eqx.field(static=True)
195
+ base: int = eqx.field()
196
196
 
197
197
  def __init__(self, base: int = 10_000) -> None:
198
198
  """Defines a rotary embeddings module.
xax/nn/geom.py CHANGED
@@ -207,7 +207,7 @@ def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
207
207
 
208
208
  def normalize(v: jnp.ndarray, axis: int = -1, eps: float = 1e-8) -> jnp.ndarray:
209
209
  norm = jnp.linalg.norm(v, axis=axis, keepdims=True)
210
- return v / jnp.clip(norm, a_min=eps)
210
+ return v / jnp.clip(norm, min=eps)
211
211
 
212
212
 
213
213
  def rotation6d_to_rotation_matrix(r6d: jnp.ndarray) -> jnp.ndarray:
@@ -299,28 +299,28 @@ def rotation_matrix_to_quat(rotation_matrix: Array, eps: float = 1e-6) -> Array:
299
299
  trace = m00 + m11 + m22
300
300
 
301
301
  # Case 0: trace is positive
302
- s0 = jnp.sqrt(jnp.clip(trace + 1.0, a_min=0.0)) * 2.0 # S = 4 * qw
302
+ s0 = jnp.sqrt(jnp.clip(trace + 1.0, min=0.0)) * 2.0 # S = 4 * qw
303
303
  w0 = 0.25 * s0
304
304
  x0 = (m21 - m12) / jnp.where(s0 < eps, 1.0, s0)
305
305
  y0 = (m02 - m20) / jnp.where(s0 < eps, 1.0, s0)
306
306
  z0 = (m10 - m01) / jnp.where(s0 < eps, 1.0, s0)
307
307
 
308
308
  # Case 1: m00 is the largest diagonal term
309
- s1 = jnp.sqrt(jnp.clip(1.0 + m00 - m11 - m22, a_min=0.0)) * 2.0 # S = 4 * qx
309
+ s1 = jnp.sqrt(jnp.clip(1.0 + m00 - m11 - m22, min=0.0)) * 2.0 # S = 4 * qx
310
310
  w1 = (m21 - m12) / jnp.where(s1 < eps, 1.0, s1)
311
311
  x1 = 0.25 * s1
312
312
  y1 = (m01 + m10) / jnp.where(s1 < eps, 1.0, s1)
313
313
  z1 = (m02 + m20) / jnp.where(s1 < eps, 1.0, s1)
314
314
 
315
315
  # Case 2: m11 is the largest diagonal term
316
- s2 = jnp.sqrt(jnp.clip(1.0 + m11 - m00 - m22, a_min=0.0)) * 2.0 # S = 4 * qy
316
+ s2 = jnp.sqrt(jnp.clip(1.0 + m11 - m00 - m22, min=0.0)) * 2.0 # S = 4 * qy
317
317
  w2 = (m02 - m20) / jnp.where(s2 < eps, 1.0, s2)
318
318
  x2 = (m01 + m10) / jnp.where(s2 < eps, 1.0, s2)
319
319
  y2 = 0.25 * s2
320
320
  z2 = (m12 + m21) / jnp.where(s2 < eps, 1.0, s2)
321
321
 
322
322
  # Case 3: m22 is the largest diagonal term
323
- s3 = jnp.sqrt(jnp.clip(1.0 + m22 - m00 - m11, a_min=0.0)) * 2.0 # S = 4 * qz
323
+ s3 = jnp.sqrt(jnp.clip(1.0 + m22 - m00 - m11, min=0.0)) * 2.0 # S = 4 * qz
324
324
  w3 = (m10 - m01) / jnp.where(s3 < eps, 1.0, s3)
325
325
  x3 = (m02 + m20) / jnp.where(s3 < eps, 1.0, s3)
326
326
  y3 = (m12 + m21) / jnp.where(s3 < eps, 1.0, s3)
xax/nn/ssm.py CHANGED
@@ -222,12 +222,12 @@ class DiscreteDiagSSMBlock(DiagSSMBlock):
222
222
 
223
223
 
224
224
  class SSM(eqx.Module):
225
- vocab_embedding: eqx.nn.Embedding
226
- output_layer: eqx.nn.Linear
227
- blocks: list[BaseSSMBlock]
228
- num_layers: int = eqx.static_field()
229
- hidden_size: int = eqx.static_field()
230
- skip_connections: bool = eqx.static_field()
225
+ vocab_embedding: eqx.nn.Embedding = eqx.field()
226
+ output_layer: eqx.nn.Linear = eqx.field()
227
+ blocks: list[BaseSSMBlock] = eqx.field()
228
+ num_layers: int = eqx.field()
229
+ hidden_size: int = eqx.field()
230
+ skip_connections: bool = eqx.field()
231
231
 
232
232
  def __init__(
233
233
  self,
xax/task/mixins/train.py CHANGED
@@ -40,7 +40,12 @@ from xax.core.state import Phase, State
40
40
  from xax.nn.functions import set_random_seed
41
41
  from xax.nn.parallel import is_master
42
42
  from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
43
- from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart, load_ckpt
43
+ from xax.task.mixins.checkpointing import (
44
+ CheckpointingConfig,
45
+ CheckpointingMixin,
46
+ CheckpointPart,
47
+ load_ckpt,
48
+ )
44
49
  from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
45
50
  from xax.task.mixins.logger import LoggerConfig, LoggerMixin
46
51
  from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.3
3
+ Version: 0.3.5
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=ffVd9_qSVuEAIPn6eK_6N8qEfDALRJigArHZQGy6y1o,15819
1
+ xax/__init__.py,sha256=OeW6UObyosw6eJSEQ96AfRJIKHg5WyZ6xuZLJdcR6cg,16240
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -8,14 +8,14 @@ xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
9
9
  xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
10
10
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- xax/nn/attention.py,sha256=0essK90OO3x9FxnUqU0DhufwXKRMN41zMtRCki5iKzQ,24742
12
- xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
11
+ xax/nn/attention.py,sha256=ESO6THJ5ORKxSM8LziRLEkj1d_QXtDndPi80Puyo-xA,28033
12
+ xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
13
13
  xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
14
- xax/nn/geom.py,sha256=6rBQrZRX1miG08VG-s8phPjA6MEFxUAfQVPt5F0RQQI,10645
14
+ xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
15
15
  xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
16
16
  xax/nn/metrics.py,sha256=zuvPXlRQczBTLHD4ilNGmZaiq6Yie3rxCMq6JkI_kos,3154
17
17
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
18
- xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
18
+ xax/nn/ssm.py,sha256=qSBv_FobnaFA5jt87OF5P2q5ih6sj4SlehhEhEFaPjA,10766
19
19
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
20
  xax/task/base.py,sha256=i6FRJ75aqlekWkzJNRWDUEX7P514pUjLVuxjhX1GBgw,8198
21
21
  xax/task/logger.py,sha256=Bmhl4mv08Aq49ZyX6BdjPIsPJK28e8s3mVFatM4IY2Q,41060
@@ -42,7 +42,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
42
42
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
43
43
  xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
44
44
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
45
- xax/task/mixins/train.py,sha256=TZatz5QwTfrNhQTiO2IqrmQY9P4Lay6FAD2VsQpWa54,33245
45
+ xax/task/mixins/train.py,sha256=bjBoigTCjbq9H4hcqIO32irHBc9rC2zkgXrnGNI2RtI,33266
46
46
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
47
47
  xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
48
48
  xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
@@ -59,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
59
59
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
60
60
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
61
61
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
62
- xax-0.3.3.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
63
- xax-0.3.3.dist-info/METADATA,sha256=mjIzoFZDSR3V1-2LHbvup6wDVa4vLiqbqiNWLsKCXY8,1246
64
- xax-0.3.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
65
- xax-0.3.3.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
66
- xax-0.3.3.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
67
- xax-0.3.3.dist-info/RECORD,,
62
+ xax-0.3.5.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
63
+ xax-0.3.5.dist-info/METADATA,sha256=kMRKGih6o7SfqGrvGQW_7OkFST6PDnbPuopnfx_bAOs,1246
64
+ xax-0.3.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
65
+ xax-0.3.5.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
66
+ xax-0.3.5.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
67
+ xax-0.3.5.dist-info/RECORD,,
File without changes