xax 0.3.3__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +23 -8
- xax/nn/attention.py +519 -408
- xax/nn/embeddings.py +10 -10
- xax/nn/geom.py +5 -5
- xax/nn/ssm.py +6 -6
- xax/task/mixins/train.py +6 -1
- {xax-0.3.3.dist-info → xax-0.3.5.dist-info}/METADATA +1 -1
- {xax-0.3.3.dist-info → xax-0.3.5.dist-info}/RECORD +12 -12
- {xax-0.3.3.dist-info → xax-0.3.5.dist-info}/WHEEL +0 -0
- {xax-0.3.3.dist-info → xax-0.3.5.dist-info}/entry_points.txt +0 -0
- {xax-0.3.3.dist-info → xax-0.3.5.dist-info}/licenses/LICENSE +0 -0
- {xax-0.3.3.dist-info → xax-0.3.5.dist-info}/top_level.txt +0 -0
xax/nn/embeddings.py
CHANGED
@@ -33,10 +33,10 @@ class LearnedPositionalEmbeddings(eqx.Module):
|
|
33
33
|
learnable: Whether the embeddings are learnable.
|
34
34
|
"""
|
35
35
|
|
36
|
-
max_tsz: int = eqx.field(
|
37
|
-
embed_dim: int = eqx.field(
|
38
|
-
learnable: bool = eqx.field(
|
39
|
-
embeddings_tc: Array
|
36
|
+
max_tsz: int = eqx.field()
|
37
|
+
embed_dim: int = eqx.field()
|
38
|
+
learnable: bool = eqx.field()
|
39
|
+
embeddings_tc: Array = eqx.field()
|
40
40
|
|
41
41
|
def __init__(
|
42
42
|
self,
|
@@ -74,10 +74,10 @@ class SinusoidalEmbeddings(eqx.Module):
|
|
74
74
|
base: The base for the sinusoidal embeddings.
|
75
75
|
"""
|
76
76
|
|
77
|
-
base: int = eqx.field(
|
78
|
-
max_tsz: int | None = eqx.field(
|
79
|
-
embed_dim: int | None = eqx.field(
|
80
|
-
embeddings_tc: Array | None
|
77
|
+
base: int = eqx.field()
|
78
|
+
max_tsz: int | None = eqx.field()
|
79
|
+
embed_dim: int | None = eqx.field()
|
80
|
+
embeddings_tc: Array | None = eqx.field()
|
81
81
|
|
82
82
|
def __init__(
|
83
83
|
self,
|
@@ -91,8 +91,8 @@ class SinusoidalEmbeddings(eqx.Module):
|
|
91
91
|
self.max_tsz = max_tsz
|
92
92
|
self.embed_dim = embed_dim
|
93
93
|
self.base = base
|
94
|
+
self.embeddings_tc = None
|
94
95
|
|
95
|
-
self.embeddings_tc: Array | None = None
|
96
96
|
if learnable:
|
97
97
|
assert max_tsz is not None, "Learnable parameters require `max_tsz` to be set"
|
98
98
|
assert embed_dim is not None, "Learnable parameters require `embed_dim` to be set"
|
@@ -192,7 +192,7 @@ class RotaryEmbeddings(eqx.Module):
|
|
192
192
|
base: The base for the sinusoidal embeddings.
|
193
193
|
"""
|
194
194
|
|
195
|
-
base: int = eqx.field(
|
195
|
+
base: int = eqx.field()
|
196
196
|
|
197
197
|
def __init__(self, base: int = 10_000) -> None:
|
198
198
|
"""Defines a rotary embeddings module.
|
xax/nn/geom.py
CHANGED
@@ -207,7 +207,7 @@ def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
|
|
207
207
|
|
208
208
|
def normalize(v: jnp.ndarray, axis: int = -1, eps: float = 1e-8) -> jnp.ndarray:
|
209
209
|
norm = jnp.linalg.norm(v, axis=axis, keepdims=True)
|
210
|
-
return v / jnp.clip(norm,
|
210
|
+
return v / jnp.clip(norm, min=eps)
|
211
211
|
|
212
212
|
|
213
213
|
def rotation6d_to_rotation_matrix(r6d: jnp.ndarray) -> jnp.ndarray:
|
@@ -299,28 +299,28 @@ def rotation_matrix_to_quat(rotation_matrix: Array, eps: float = 1e-6) -> Array:
|
|
299
299
|
trace = m00 + m11 + m22
|
300
300
|
|
301
301
|
# Case 0: trace is positive
|
302
|
-
s0 = jnp.sqrt(jnp.clip(trace + 1.0,
|
302
|
+
s0 = jnp.sqrt(jnp.clip(trace + 1.0, min=0.0)) * 2.0 # S = 4 * qw
|
303
303
|
w0 = 0.25 * s0
|
304
304
|
x0 = (m21 - m12) / jnp.where(s0 < eps, 1.0, s0)
|
305
305
|
y0 = (m02 - m20) / jnp.where(s0 < eps, 1.0, s0)
|
306
306
|
z0 = (m10 - m01) / jnp.where(s0 < eps, 1.0, s0)
|
307
307
|
|
308
308
|
# Case 1: m00 is the largest diagonal term
|
309
|
-
s1 = jnp.sqrt(jnp.clip(1.0 + m00 - m11 - m22,
|
309
|
+
s1 = jnp.sqrt(jnp.clip(1.0 + m00 - m11 - m22, min=0.0)) * 2.0 # S = 4 * qx
|
310
310
|
w1 = (m21 - m12) / jnp.where(s1 < eps, 1.0, s1)
|
311
311
|
x1 = 0.25 * s1
|
312
312
|
y1 = (m01 + m10) / jnp.where(s1 < eps, 1.0, s1)
|
313
313
|
z1 = (m02 + m20) / jnp.where(s1 < eps, 1.0, s1)
|
314
314
|
|
315
315
|
# Case 2: m11 is the largest diagonal term
|
316
|
-
s2 = jnp.sqrt(jnp.clip(1.0 + m11 - m00 - m22,
|
316
|
+
s2 = jnp.sqrt(jnp.clip(1.0 + m11 - m00 - m22, min=0.0)) * 2.0 # S = 4 * qy
|
317
317
|
w2 = (m02 - m20) / jnp.where(s2 < eps, 1.0, s2)
|
318
318
|
x2 = (m01 + m10) / jnp.where(s2 < eps, 1.0, s2)
|
319
319
|
y2 = 0.25 * s2
|
320
320
|
z2 = (m12 + m21) / jnp.where(s2 < eps, 1.0, s2)
|
321
321
|
|
322
322
|
# Case 3: m22 is the largest diagonal term
|
323
|
-
s3 = jnp.sqrt(jnp.clip(1.0 + m22 - m00 - m11,
|
323
|
+
s3 = jnp.sqrt(jnp.clip(1.0 + m22 - m00 - m11, min=0.0)) * 2.0 # S = 4 * qz
|
324
324
|
w3 = (m10 - m01) / jnp.where(s3 < eps, 1.0, s3)
|
325
325
|
x3 = (m02 + m20) / jnp.where(s3 < eps, 1.0, s3)
|
326
326
|
y3 = (m12 + m21) / jnp.where(s3 < eps, 1.0, s3)
|
xax/nn/ssm.py
CHANGED
@@ -222,12 +222,12 @@ class DiscreteDiagSSMBlock(DiagSSMBlock):
|
|
222
222
|
|
223
223
|
|
224
224
|
class SSM(eqx.Module):
|
225
|
-
vocab_embedding: eqx.nn.Embedding
|
226
|
-
output_layer: eqx.nn.Linear
|
227
|
-
blocks: list[BaseSSMBlock]
|
228
|
-
num_layers: int = eqx.
|
229
|
-
hidden_size: int = eqx.
|
230
|
-
skip_connections: bool = eqx.
|
225
|
+
vocab_embedding: eqx.nn.Embedding = eqx.field()
|
226
|
+
output_layer: eqx.nn.Linear = eqx.field()
|
227
|
+
blocks: list[BaseSSMBlock] = eqx.field()
|
228
|
+
num_layers: int = eqx.field()
|
229
|
+
hidden_size: int = eqx.field()
|
230
|
+
skip_connections: bool = eqx.field()
|
231
231
|
|
232
232
|
def __init__(
|
233
233
|
self,
|
xax/task/mixins/train.py
CHANGED
@@ -40,7 +40,12 @@ from xax.core.state import Phase, State
|
|
40
40
|
from xax.nn.functions import set_random_seed
|
41
41
|
from xax.nn.parallel import is_master
|
42
42
|
from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
|
43
|
-
from xax.task.mixins.checkpointing import
|
43
|
+
from xax.task.mixins.checkpointing import (
|
44
|
+
CheckpointingConfig,
|
45
|
+
CheckpointingMixin,
|
46
|
+
CheckpointPart,
|
47
|
+
load_ckpt,
|
48
|
+
)
|
44
49
|
from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
|
45
50
|
from xax.task.mixins.logger import LoggerConfig, LoggerMixin
|
46
51
|
from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=OeW6UObyosw6eJSEQ96AfRJIKHg5WyZ6xuZLJdcR6cg,16240
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
@@ -8,14 +8,14 @@ xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
8
|
xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
|
9
9
|
xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
|
10
10
|
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
xax/nn/attention.py,sha256=
|
12
|
-
xax/nn/embeddings.py,sha256=
|
11
|
+
xax/nn/attention.py,sha256=ESO6THJ5ORKxSM8LziRLEkj1d_QXtDndPi80Puyo-xA,28033
|
12
|
+
xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
|
13
13
|
xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
|
14
|
-
xax/nn/geom.py,sha256=
|
14
|
+
xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
|
15
15
|
xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
|
16
16
|
xax/nn/metrics.py,sha256=zuvPXlRQczBTLHD4ilNGmZaiq6Yie3rxCMq6JkI_kos,3154
|
17
17
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
18
|
-
xax/nn/ssm.py,sha256=
|
18
|
+
xax/nn/ssm.py,sha256=qSBv_FobnaFA5jt87OF5P2q5ih6sj4SlehhEhEFaPjA,10766
|
19
19
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
20
20
|
xax/task/base.py,sha256=i6FRJ75aqlekWkzJNRWDUEX7P514pUjLVuxjhX1GBgw,8198
|
21
21
|
xax/task/logger.py,sha256=Bmhl4mv08Aq49ZyX6BdjPIsPJK28e8s3mVFatM4IY2Q,41060
|
@@ -42,7 +42,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
|
|
42
42
|
xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
|
43
43
|
xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
|
44
44
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
45
|
-
xax/task/mixins/train.py,sha256=
|
45
|
+
xax/task/mixins/train.py,sha256=bjBoigTCjbq9H4hcqIO32irHBc9rC2zkgXrnGNI2RtI,33266
|
46
46
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
47
47
|
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
48
48
|
xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
|
@@ -59,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
59
59
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
60
60
|
xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
|
61
61
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
62
|
-
xax-0.3.
|
63
|
-
xax-0.3.
|
64
|
-
xax-0.3.
|
65
|
-
xax-0.3.
|
66
|
-
xax-0.3.
|
67
|
-
xax-0.3.
|
62
|
+
xax-0.3.5.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
63
|
+
xax-0.3.5.dist-info/METADATA,sha256=kMRKGih6o7SfqGrvGQW_7OkFST6PDnbPuopnfx_bAOs,1246
|
64
|
+
xax-0.3.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
65
|
+
xax-0.3.5.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
|
66
|
+
xax-0.3.5.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
67
|
+
xax-0.3.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|