xax 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -11,7 +11,7 @@ This file can be maintained by running the update script:
11
11
  python -m scripts.update_api --inplace
12
12
  """
13
13
 
14
- __version__ = "0.0.6"
14
+ __version__ = "0.0.7"
15
15
 
16
16
  # This list shouldn't be modified by hand; instead, run the update script.
17
17
  __all__ = [
@@ -34,6 +34,8 @@ __all__ = [
34
34
  "get_positional_embeddings",
35
35
  "get_rotary_embeddings",
36
36
  "rotary_embeddings",
37
+ "euler_to_quat",
38
+ "quat_to_euler",
37
39
  "is_master",
38
40
  "BaseLauncher",
39
41
  "CliLauncher",
@@ -78,11 +80,18 @@ __all__ = [
78
80
  "save_config",
79
81
  "stage_environment",
80
82
  "to_markdown_table",
83
+ "jit",
81
84
  "ColoredFormatter",
82
85
  "configure_logging",
83
86
  "one_hot",
84
87
  "partial_flatten",
85
88
  "worker_chunk",
89
+ "profile",
90
+ "compute_nan_ratio",
91
+ "flatten_array",
92
+ "flatten_pytree",
93
+ "slice_array",
94
+ "slice_pytree",
86
95
  "TextBlock",
87
96
  "camelcase_to_snakecase",
88
97
  "colored",
@@ -142,6 +151,8 @@ NAME_MAP: dict[str, str] = {
142
151
  "get_positional_embeddings": "nn.embeddings",
143
152
  "get_rotary_embeddings": "nn.embeddings",
144
153
  "rotary_embeddings": "nn.embeddings",
154
+ "euler_to_quat": "nn.geom",
155
+ "quat_to_euler": "nn.geom",
145
156
  "is_master": "nn.parallel",
146
157
  "BaseLauncher": "task.launchers.base",
147
158
  "CliLauncher": "task.launchers.cli",
@@ -186,11 +197,18 @@ NAME_MAP: dict[str, str] = {
186
197
  "save_config": "utils.experiments",
187
198
  "stage_environment": "utils.experiments",
188
199
  "to_markdown_table": "utils.experiments",
200
+ "jit": "utils.jax",
189
201
  "ColoredFormatter": "utils.logging",
190
202
  "configure_logging": "utils.logging",
191
203
  "one_hot": "utils.numpy",
192
204
  "partial_flatten": "utils.numpy",
193
205
  "worker_chunk": "utils.numpy",
206
+ "profile": "utils.profile",
207
+ "compute_nan_ratio": "utils.pytree",
208
+ "flatten_array": "utils.pytree",
209
+ "flatten_pytree": "utils.pytree",
210
+ "slice_array": "utils.pytree",
211
+ "slice_pytree": "utils.pytree",
194
212
  "TextBlock": "utils.text",
195
213
  "camelcase_to_snakecase": "utils.text",
196
214
  "colored": "utils.text",
@@ -257,6 +275,7 @@ if IMPORT_ALL or TYPE_CHECKING:
257
275
  get_rotary_embeddings,
258
276
  rotary_embeddings,
259
277
  )
278
+ from xax.nn.geom import euler_to_quat, quat_to_euler
260
279
  from xax.nn.parallel import is_master
261
280
  from xax.task.base import RawConfigType
262
281
  from xax.task.launchers.base import BaseLauncher
@@ -299,6 +318,7 @@ if IMPORT_ALL or TYPE_CHECKING:
299
318
  stage_environment,
300
319
  to_markdown_table,
301
320
  )
321
+ from xax.utils.jax import jit
302
322
  from xax.utils.logging import (
303
323
  LOG_ERROR_SUMMARY,
304
324
  LOG_PING,
@@ -307,6 +327,14 @@ if IMPORT_ALL or TYPE_CHECKING:
307
327
  configure_logging,
308
328
  )
309
329
  from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
330
+ from xax.utils.profile import profile
331
+ from xax.utils.pytree import (
332
+ compute_nan_ratio,
333
+ flatten_array,
334
+ flatten_pytree,
335
+ slice_array,
336
+ slice_pytree,
337
+ )
310
338
  from xax.utils.text import (
311
339
  TextBlock,
312
340
  camelcase_to_snakecase,
xax/nn/geom.py ADDED
@@ -0,0 +1,75 @@
1
+ """Defines geometry functions."""
2
+
3
+ import jax
4
+ from jax import numpy as jnp
5
+
6
+
7
def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
    """Normalizes and converts a quaternion (w, x, y, z) to roll, pitch, yaw.

    Angles are returned in radians. Pitch is saturated at +-pi/2 when the
    pitch sine reaches or exceeds 1 in magnitude (gimbal lock or rounding),
    and the function stays differentiable there (gradients are finite).

    Args:
        quat_4: The quaternion to convert, shape (*, 4).
        eps: A small epsilon value to avoid division by zero.

    Returns:
        The roll, pitch, yaw angles with shape (*, 3).
    """
    quat_4 = quat_4 / (jnp.linalg.norm(quat_4, axis=-1, keepdims=True) + eps)
    w, x, y, z = jnp.split(quat_4, 4, axis=-1)

    # Roll (x-axis rotation)
    sinr_cosp = 2.0 * (w * x + y * z)
    cosr_cosp = 1.0 - 2.0 * (x * x + y * y)
    roll = jnp.arctan2(sinr_cosp, cosr_cosp)

    # Pitch (y-axis rotation)
    sinp = 2.0 * (w * y - z * x)
    saturated = jnp.abs(sinp) >= 1.0

    # `jnp.where` differentiates through BOTH branches, and arcsin's derivative
    # is inf/NaN for |sinp| >= 1, which would turn the gradient into NaN
    # (0 * NaN = NaN). Feed arcsin a harmless value in the saturated region so
    # the unused branch stays finite; the forward result is unchanged.
    safe_sinp = jnp.where(saturated, 0.0, sinp)
    pitch = jnp.where(
        saturated,
        jnp.sign(sinp) * jnp.pi / 2.0,  # Use 90 degrees if out of range
        jnp.arcsin(safe_sinp),
    )

    # Yaw (z-axis rotation)
    siny_cosp = 2.0 * (w * z + x * y)
    cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
    yaw = jnp.arctan2(siny_cosp, cosy_cosp)

    return jnp.concatenate([roll, pitch, yaw], axis=-1)
41
+
42
+
43
def euler_to_quat(euler_3: jax.Array) -> jax.Array:
    """Builds a unit quaternion (w, x, y, z) from roll, pitch, yaw angles.

    Args:
        euler_3: Roll, pitch and yaw in radians along the last axis,
            shape (*, 3).

    Returns:
        The normalized quaternion, shape (*, 4).
    """
    # Every term of the conversion uses half-angles.
    half_roll, half_pitch, half_yaw = (angle * 0.5 for angle in jnp.split(euler_3, 3, axis=-1))

    cr, sr = jnp.cos(half_roll), jnp.sin(half_roll)
    cp, sp = jnp.cos(half_pitch), jnp.sin(half_pitch)
    cy, sy = jnp.cos(half_yaw), jnp.sin(half_yaw)

    components = [
        cr * cp * cy + sr * sp * sy,  # w
        sr * cp * cy - cr * sp * sy,  # x
        cr * sp * cy + sr * cp * sy,  # y
        cr * cp * sy - sr * sp * cy,  # z
    ]
    quat = jnp.concatenate(components, axis=-1)

    # Re-normalize to remove any floating-point drift from unit length.
    return quat / jnp.linalg.norm(quat, axis=-1, keepdims=True)
xax/utils/jax.py CHANGED
@@ -1,14 +1,140 @@
1
1
  """Defines some utility functions for interfacing with Jax."""
2
2
 
3
+ import inspect
4
+ import logging
5
+ import os
6
+ import time
7
+ from functools import wraps
8
+ from typing import Any, Callable, Iterable, ParamSpec, Sequence, TypeVar, cast
9
+
10
+ import jax
3
11
  import jax.numpy as jnp
4
12
  import numpy as np
13
+ from jax._src import sharding_impls
14
+ from jax._src.lib import xla_client as xc
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ DEFAULT_COMPILE_TIMEOUT = 1.0
5
19
 
6
20
  Number = int | float | np.ndarray | jnp.ndarray
7
21
 
8
22
 
23
+ P = ParamSpec("P") # For function parameters
24
+ R = TypeVar("R") # For function return type
25
+
26
+
9
27
  def as_float(value: int | float | np.ndarray | jnp.ndarray) -> float:
10
28
  if isinstance(value, (int, float)):
11
29
  return float(value)
12
30
  if isinstance(value, (np.ndarray, jnp.ndarray)):
13
31
  return float(value.item())
14
32
  raise TypeError(f"Unexpected type: {type(value)}")
33
+
34
+
35
def get_hash(obj: object) -> int:
    """Get a hash of an object.

    If the object is hashable, use the hash. Otherwise, use the id.

    Args:
        obj: Any object.

    Returns:
        ``hash(obj)`` when the object is hashable, else ``id(obj)``.
    """
    try:
        return hash(obj)
    except TypeError:
        # Unhashable types (lists, dicts, NumPy/JAX arrays, tuples containing them)
        # set ``__hash__ = None`` or fail while hashing their contents, so
        # ``hasattr(obj, "__hash__")`` cannot detect them; fall back to identity.
        return id(obj)
43
+
44
+
45
def jit(
    in_shardings: Any = sharding_impls.UNSPECIFIED,  # noqa: ANN401
    out_shardings: Any = sharding_impls.UNSPECIFIED,  # noqa: ANN401
    static_argnums: int | Sequence[int] | None = None,
    static_argnames: str | Iterable[str] | None = None,
    donate_argnums: int | Sequence[int] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: xc.Device | None = None,
    backend: str | None = None,
    inline: bool = False,
    abstracted_axes: Any | None = None,  # noqa: ANN401
    compiler_options: dict[str, Any] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Wrapper function that provides utility improvements over Jax's JIT.

    Specifically, this function works on class methods, is toggleable, and
    helps diagnose recompilations by logging which argument hashes changed
    between consecutive profiled calls.

    This is meant to be used as a decorator factory: every argument is
    forwarded unchanged to ``jax.jit``, and the decorated function calls
    ``wrapped``. Two environment variables are read on every call:

    - ``DEBUG=1``: bypass JIT entirely and call the original function.
    - ``JIT_PROFILE=1``: log each call, its wall-clock time, and the
      arguments whose hashes differ from the previous profiled call.

    Returns:
        A decorator that applies the behavior described above.
    """

    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
        class JitState:
            # Number of profiled calls so far (incremented once per profiled call).
            compilation_count = 0
            # Argument-name -> hash mapping recorded on the previous profiled call.
            last_arg_dict: dict[str, int] | None = None

        # Used to map positional arguments back to parameter names when logging.
        param_names = list(inspect.signature(fn).parameters.keys())

        jitted_fn = jax.jit(
            fn,
            in_shardings=in_shardings,
            out_shardings=out_shardings,
            static_argnums=static_argnums,
            static_argnames=static_argnames,
            donate_argnums=donate_argnums,
            donate_argnames=donate_argnames,
            keep_unused=keep_unused,
            device=device,
            backend=backend,
            inline=inline,
            abstracted_axes=abstracted_axes,
            compiler_options=compiler_options,
        )

        @wraps(fn)
        def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
            if os.environ.get("DEBUG", "0") == "1":  # skipping during debug
                return fn(*args, **kwargs)

            # Fast path: no timing or hashing overhead when profiling is off.
            if os.environ.get("JIT_PROFILE", "0") != "1":
                return cast(R, jitted_fn(*args, **kwargs))

            # For `__call__`, the bare function name is uninformative, so prefix the instance's class.
            class_name = args[0].__class__.__name__ + "." if args and fn.__name__ == "__call__" else ""
            logger.info(
                "Currently running %s (count: %s)",
                f"{class_name}{fn.__name__}",
                JitState.compilation_count,
            )

            start_time = time.perf_counter()
            res = jitted_fn(*args, **kwargs)
            runtime = time.perf_counter() - start_time

            # NOTE: JAX dispatches asynchronously, so this covers tracing/compilation
            # (on a cache miss) plus dispatch, not necessarily device execution.
            logger.info("Call took %s seconds", runtime)
            JitState.compilation_count += 1

            # Hash every argument so that changes between calls (a common cause of
            # recompilation) can be reported.
            # TODO: we should probably reimplement the lower-level jitting logic to detect recompilations directly
            arg_dict = {name: get_hash(arg) for name, arg in zip(param_names, args)}
            arg_dict.update({k: get_hash(v) for k, v in kwargs.items()})

            if JitState.last_arg_dict is not None:
                for k in set(arg_dict) | set(JitState.last_arg_dict):
                    prev = JitState.last_arg_dict.get(k, "N/A")
                    curr = arg_dict.get(k, "N/A")
                    if prev != curr:
                        logger.info("- Arg '%s' hash changed: %s -> %s", k, prev, curr)

            JitState.last_arg_dict = arg_dict

            return cast(R, res)

        return wrapped

    return decorator
xax/utils/profile.py ADDED
@@ -0,0 +1,61 @@
1
+ """Profiling utilities."""
2
+
3
+ import logging
4
+ import os
5
+ import time
6
+ from functools import wraps
7
+ from typing import Callable, ParamSpec, TypeVar
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ P = ParamSpec("P") # For function parameters
12
+ R = TypeVar("R") # For function return type
13
+
14
+
15
+ def profile(fn: Callable[P, R]) -> Callable[P, R]:
16
+ """Profiling decorator that tracks function call count and execution time.
17
+
18
+ Activated when the PROFILE environment variable is set to "1".
19
+
20
+ Returns:
21
+ A decorated function with profiling capabilities.
22
+ """
23
+
24
+ class ProfileState:
25
+ call_count = 0
26
+ total_time = 0.0
27
+
28
+ @wraps(fn)
29
+ def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
30
+ if os.environ.get("PROFILE", "0") != "1":
31
+ return fn(*args, **kwargs)
32
+
33
+ start_time = time.time()
34
+ res = fn(*args, **kwargs)
35
+ end_time = time.time()
36
+ runtime = end_time - start_time
37
+
38
+ ProfileState.call_count += 1
39
+ ProfileState.total_time += runtime
40
+
41
+ # Handle class methods by showing class name
42
+ if fn.__name__ == "__call__" or (args and hasattr(args[0], "__class__")):
43
+ try:
44
+ class_name = args[0].__class__.__name__ + "."
45
+ except (IndexError, AttributeError):
46
+ class_name = ""
47
+ else:
48
+ class_name = ""
49
+
50
+ logger.info(
51
+ "%s %s - call #%s, took %s seconds, total: %s seconds",
52
+ class_name,
53
+ fn.__name__,
54
+ ProfileState.call_count,
55
+ runtime,
56
+ ProfileState.total_time,
57
+ )
58
+
59
+ return res
60
+
61
+ return wrapped
xax/utils/pytree.py ADDED
@@ -0,0 +1,50 @@
1
+ """Utils for accessing, modifying, and otherwise manipulating pytrees."""
2
+
3
+ import chex
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from jax import Array
7
+ from jaxtyping import PyTree
8
+
9
+
10
def slice_array(x: Array, start: Array, slice_length: int) -> Array:
    """Returns ``slice_length`` consecutive rows of ``x`` starting at index ``start``.

    Only the leading dimension is sliced; every trailing dimension is kept
    whole, so the result has shape ``(slice_length, *x.shape[1:])``.

    Args:
        x: The array to slice.
        start: Scalar start index along the first dimension (may be traced).
        slice_length: Number of rows to take; passed as a slice size to
            ``jax.lax.dynamic_slice``.

    Returns:
        The sliced array.
    """
    chex.assert_shape(start, ())
    chex.assert_shape(slice_length, ())

    trailing_dims = x.shape[1:]
    offsets = (start, *([0] * len(trailing_dims)))
    sizes = (slice_length, *trailing_dims)
    return jax.lax.dynamic_slice(x, offsets, sizes)
22
+
23
+
24
def slice_pytree(pytree: PyTree, start: Array, slice_length: int) -> PyTree:
    """Slices every leaf of ``pytree`` with ``slice_array`` using a shared ``start`` and ``slice_length``."""

    def _slice_leaf(leaf: Array) -> Array:
        return slice_array(leaf, start, slice_length)

    return jax.tree_util.tree_map(_slice_leaf, pytree)
27
+
28
+
29
def flatten_array(x: Array, flatten_size: int) -> Array:
    """Merges the first two axes of an array into a single leading axis.

    An array of shape ``(A, B, *rest)`` becomes ``(flatten_size, *rest)``, so
    ``flatten_size`` must equal ``A * B``. Pass ``-1`` to have it inferred.
    For a 1-D array this is a plain reshape to ``(flatten_size,)``.

    Args:
        x: The array to flatten.
        flatten_size: Size of the merged leading axis, or ``-1`` to infer it.

    Returns:
        The reshaped array.
    """
    # `jnp.reshape` already rejects sizes that don't match the element count,
    # so no post-hoc assert is needed (one would also wrongly reject -1).
    return jnp.reshape(x, (flatten_size, *x.shape[2:]))
34
+
35
+
36
def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
    """Applies ``flatten_array`` to every leaf, giving each a leading axis of ``flatten_size``."""

    def _flatten_leaf(leaf: Array) -> Array:
        return flatten_array(leaf, flatten_size)

    return jax.tree_util.tree_map(_flatten_leaf, pytree)
39
+
40
+
41
def compute_nan_ratio(pytree: PyTree) -> Array:
    """Computes the ratio of NaN entries to total entries in a given PyTree.

    Each leaf is checked with ``jnp.isnan`` and counted with ``jnp.size``, so
    both array leaves and plain Python scalar leaves are supported.

    Args:
        pytree: A PyTree whose leaves are numeric arrays or scalars.

    Returns:
        A scalar array holding ``num_nans / num_elements``. A PyTree with no
        elements (no leaves, or only empty leaves) yields 0.0 instead of
        dividing by zero.
    """
    leaves = jax.tree_util.tree_leaves(pytree)

    # Element counts come from static shapes, so this is a Python int even under `jax.jit`.
    total_elements = sum(jnp.size(leaf) for leaf in leaves)
    total_nans = sum(jnp.sum(jnp.isnan(leaf)) for leaf in leaves)

    # With no elements the numerator is also 0, so clamping the denominator gives 0.0.
    return jnp.array(total_nans / max(total_elements, 1))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: xax
3
- Version: 0.0.6
3
+ Version: 0.0.7
4
4
  Summary: The xax project
5
5
  Home-page: https://github.com/dpshai/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=RTUsDh_R0TFa09q-_U0vd-eCYRC-bCaHqHlayp8U2hU,9736
1
+ xax/__init__.py,sha256=ScTkvKaxgpuKhhs9RINJa2XWCj899ndSYrB3FtScfxw,10509
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=NmU9PNJhfLtNqqtWWf8WqMjgbBPCn_yt8oMGAgS7Fno,291
@@ -8,6 +8,7 @@ xax/core/state.py,sha256=y123fL7pMgk25TPG6KN0LRIF_eYnD9eP7OfqtoQJGNE,2178
8
8
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
10
  xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
11
+ xax/nn/geom.py,sha256=MtVar9AdqrJQGIFxcIFHyFnV_fblf9Pc4kQT_gTQASI,2195
11
12
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
12
13
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
14
  xax/task/base.py,sha256=LHDmM2c_Ps5cGEzn_QUpmyInD7zJJm3Yt9eSeij2Vus,7297
@@ -38,15 +39,17 @@ xax/task/mixins/step_wrapper.py,sha256=DJw42mUGwgKx2tkeqatKR9_F4J8ug4wmxKMeJPmhc
38
39
  xax/task/mixins/train.py,sha256=dhGL_IuDaJy39BooYlO7JO-_EotKldtBhBplDGU_AnM,21745
39
40
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
41
  xax/utils/experiments.py,sha256=qT3H0fyVH8DN417x7T0Xmz4SKoogW81-EHcZfyktFI8,28300
41
- xax/utils/jax.py,sha256=VzEVB766UyH3_cgN6UP0FkCsDuGlYg5KJj8YJS4yYUk,439
42
+ xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
42
43
  xax/utils/logging.py,sha256=ST1hp2C2xntVVJBUHwo3YxPK19fBLNvHU2WGO1xqcXA,6418
43
44
  xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
45
+ xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
46
+ xax/utils/pytree.py,sha256=Jwx6ErJfv1r2D23D4eKz1Hoo3mAJ0SEqC3EagZarWkw,1858
44
47
  xax/utils/tensorboard.py,sha256=oGq2E3Yr0z2xaACv2UOVt_CHEVc8fBxI8V1M99Fd34E,9742
45
48
  xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
46
49
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
47
50
  xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
48
- xax-0.0.6.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
49
- xax-0.0.6.dist-info/METADATA,sha256=YO2c2PUMWkH1ILfPhFWKK4Sodbo9qUpUOCIkm4aLHfg,1171
50
- xax-0.0.6.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
51
- xax-0.0.6.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
52
- xax-0.0.6.dist-info/RECORD,,
51
+ xax-0.0.7.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
52
+ xax-0.0.7.dist-info/METADATA,sha256=hE0KO4kYcN6Ed8iZ4649R5ENOUaQysBMW9vTh-94d4I,1171
53
+ xax-0.0.7.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
54
+ xax-0.0.7.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
55
+ xax-0.0.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.2)
2
+ Generator: setuptools (76.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
File without changes