xax 0.0.6__tar.gz → 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.0.6/xax.egg-info → xax-0.0.7}/PKG-INFO +1 -1
- {xax-0.0.6 → xax-0.0.7}/xax/__init__.py +29 -1
- xax-0.0.7/xax/nn/geom.py +75 -0
- xax-0.0.7/xax/utils/jax.py +140 -0
- xax-0.0.7/xax/utils/profile.py +61 -0
- xax-0.0.7/xax/utils/pytree.py +50 -0
- {xax-0.0.6 → xax-0.0.7/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.0.6 → xax-0.0.7}/xax.egg-info/SOURCES.txt +3 -0
- xax-0.0.6/xax/utils/jax.py +0 -14
- {xax-0.0.6 → xax-0.0.7}/LICENSE +0 -0
- {xax-0.0.6 → xax-0.0.7}/MANIFEST.in +0 -0
- {xax-0.0.6 → xax-0.0.7}/README.md +0 -0
- {xax-0.0.6 → xax-0.0.7}/pyproject.toml +0 -0
- {xax-0.0.6 → xax-0.0.7}/setup.cfg +0 -0
- {xax-0.0.6 → xax-0.0.7}/setup.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/core/__init__.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/core/conf.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/core/state.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/nn/__init__.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/nn/embeddings.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/nn/functions.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/nn/parallel.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/py.typed +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/requirements-dev.txt +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/requirements.txt +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/__init__.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/base.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/launchers/__init__.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/launchers/base.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/launchers/cli.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/launchers/single_process.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/logger.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/loggers/__init__.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/loggers/callback.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/loggers/json.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/loggers/state.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/loggers/stdout.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/mixins/__init__.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/mixins/compile.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/mixins/logger.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/mixins/process.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/mixins/runnable.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/mixins/train.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/script.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/task/task.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/utils/__init__.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/utils/data/__init__.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/utils/data/collate.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/utils/experiments.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/utils/logging.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/utils/numpy.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/utils/tensorboard.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax/utils/text.py +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax.egg-info/requires.txt +0 -0
- {xax-0.0.6 → xax-0.0.7}/xax.egg-info/top_level.txt +0 -0
@@ -11,7 +11,7 @@ This file can be maintained by running the update script:
|
|
11
11
|
python -m scripts.update_api --inplace
|
12
12
|
"""
|
13
13
|
|
14
|
-
__version__ = "0.0.6"
|
14
|
+
__version__ = "0.0.7"
|
15
15
|
|
16
16
|
# This list shouldn't be modified by hand; instead, run the update script.
|
17
17
|
__all__ = [
|
@@ -34,6 +34,8 @@ __all__ = [
|
|
34
34
|
"get_positional_embeddings",
|
35
35
|
"get_rotary_embeddings",
|
36
36
|
"rotary_embeddings",
|
37
|
+
"euler_to_quat",
|
38
|
+
"quat_to_euler",
|
37
39
|
"is_master",
|
38
40
|
"BaseLauncher",
|
39
41
|
"CliLauncher",
|
@@ -78,11 +80,18 @@ __all__ = [
|
|
78
80
|
"save_config",
|
79
81
|
"stage_environment",
|
80
82
|
"to_markdown_table",
|
83
|
+
"jit",
|
81
84
|
"ColoredFormatter",
|
82
85
|
"configure_logging",
|
83
86
|
"one_hot",
|
84
87
|
"partial_flatten",
|
85
88
|
"worker_chunk",
|
89
|
+
"profile",
|
90
|
+
"compute_nan_ratio",
|
91
|
+
"flatten_array",
|
92
|
+
"flatten_pytree",
|
93
|
+
"slice_array",
|
94
|
+
"slice_pytree",
|
86
95
|
"TextBlock",
|
87
96
|
"camelcase_to_snakecase",
|
88
97
|
"colored",
|
@@ -142,6 +151,8 @@ NAME_MAP: dict[str, str] = {
|
|
142
151
|
"get_positional_embeddings": "nn.embeddings",
|
143
152
|
"get_rotary_embeddings": "nn.embeddings",
|
144
153
|
"rotary_embeddings": "nn.embeddings",
|
154
|
+
"euler_to_quat": "nn.geom",
|
155
|
+
"quat_to_euler": "nn.geom",
|
145
156
|
"is_master": "nn.parallel",
|
146
157
|
"BaseLauncher": "task.launchers.base",
|
147
158
|
"CliLauncher": "task.launchers.cli",
|
@@ -186,11 +197,18 @@ NAME_MAP: dict[str, str] = {
|
|
186
197
|
"save_config": "utils.experiments",
|
187
198
|
"stage_environment": "utils.experiments",
|
188
199
|
"to_markdown_table": "utils.experiments",
|
200
|
+
"jit": "utils.jax",
|
189
201
|
"ColoredFormatter": "utils.logging",
|
190
202
|
"configure_logging": "utils.logging",
|
191
203
|
"one_hot": "utils.numpy",
|
192
204
|
"partial_flatten": "utils.numpy",
|
193
205
|
"worker_chunk": "utils.numpy",
|
206
|
+
"profile": "utils.profile",
|
207
|
+
"compute_nan_ratio": "utils.pytree",
|
208
|
+
"flatten_array": "utils.pytree",
|
209
|
+
"flatten_pytree": "utils.pytree",
|
210
|
+
"slice_array": "utils.pytree",
|
211
|
+
"slice_pytree": "utils.pytree",
|
194
212
|
"TextBlock": "utils.text",
|
195
213
|
"camelcase_to_snakecase": "utils.text",
|
196
214
|
"colored": "utils.text",
|
@@ -257,6 +275,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
257
275
|
get_rotary_embeddings,
|
258
276
|
rotary_embeddings,
|
259
277
|
)
|
278
|
+
from xax.nn.geom import euler_to_quat, quat_to_euler
|
260
279
|
from xax.nn.parallel import is_master
|
261
280
|
from xax.task.base import RawConfigType
|
262
281
|
from xax.task.launchers.base import BaseLauncher
|
@@ -299,6 +318,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
299
318
|
stage_environment,
|
300
319
|
to_markdown_table,
|
301
320
|
)
|
321
|
+
from xax.utils.jax import jit
|
302
322
|
from xax.utils.logging import (
|
303
323
|
LOG_ERROR_SUMMARY,
|
304
324
|
LOG_PING,
|
@@ -307,6 +327,14 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
307
327
|
configure_logging,
|
308
328
|
)
|
309
329
|
from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
|
330
|
+
from xax.utils.profile import profile
|
331
|
+
from xax.utils.pytree import (
|
332
|
+
compute_nan_ratio,
|
333
|
+
flatten_array,
|
334
|
+
flatten_pytree,
|
335
|
+
slice_array,
|
336
|
+
slice_pytree,
|
337
|
+
)
|
310
338
|
from xax.utils.text import (
|
311
339
|
TextBlock,
|
312
340
|
camelcase_to_snakecase,
|
xax-0.0.7/xax/nn/geom.py
ADDED
@@ -0,0 +1,75 @@
|
|
1
|
+
"""Defines geometry functions."""
|
2
|
+
|
3
|
+
import jax
|
4
|
+
from jax import numpy as jnp
|
5
|
+
|
6
|
+
|
7
|
+
def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
    """Normalizes and converts a quaternion (w, x, y, z) to roll, pitch, yaw.

    Args:
        quat_4: The quaternion to convert, shape (*, 4).
        eps: A small epsilon value added to the norm to avoid division by zero.

    Returns:
        The roll, pitch, yaw angles with shape (*, 3).
    """
    # Normalize along the last axis; `eps` keeps an all-zero input finite.
    quat_4 = quat_4 / (jnp.linalg.norm(quat_4, axis=-1, keepdims=True) + eps)

    # Each component keeps a trailing axis of size 1 so the three angles can be
    # concatenated back along the last axis at the end.
    w, x, y, z = jnp.split(quat_4, 4, axis=-1)

    # Roll (x-axis rotation)
    sinr_cosp = 2.0 * (w * x + y * z)
    cosr_cosp = 1.0 - 2.0 * (x * x + y * y)
    roll = jnp.arctan2(sinr_cosp, cosr_cosp)

    # Pitch (y-axis rotation)
    sinp = 2.0 * (w * y - z * x)

    # Handle edge cases where |sinp| >= 1, where arcsin is at or beyond the
    # edge of its domain: use +/- 90 degrees with the sign of `sinp` instead.
    pitch = jnp.where(
        jnp.abs(sinp) >= 1.0,
        jnp.sign(sinp) * jnp.pi / 2.0,
        jnp.arcsin(sinp),
    )

    # Yaw (z-axis rotation)
    siny_cosp = 2.0 * (w * z + x * y)
    cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
    yaw = jnp.arctan2(siny_cosp, cosy_cosp)

    return jnp.concatenate([roll, pitch, yaw], axis=-1)
|
41
|
+
|
42
|
+
|
43
|
+
def euler_to_quat(euler_3: jax.Array) -> jax.Array:
    """Converts roll, pitch, yaw angles to a quaternion (w, x, y, z).

    Args:
        euler_3: The roll, pitch, yaw angles, shape (*, 3).

    Returns:
        The quaternion with shape (*, 4), normalized along the last axis.
    """
    # Each angle keeps a trailing axis of size 1 so the four components can be
    # concatenated along the last axis below.
    roll, pitch, yaw = jnp.split(euler_3, 3, axis=-1)

    # Half-angle sines and cosines for each rotation.
    cr = jnp.cos(roll * 0.5)
    sr = jnp.sin(roll * 0.5)
    cp = jnp.cos(pitch * 0.5)
    sp = jnp.sin(pitch * 0.5)
    cy = jnp.cos(yaw * 0.5)
    sy = jnp.sin(yaw * 0.5)

    # Quaternion components from the half-angle terms.
    w = cr * cp * cy + sr * sp * sy
    x = sr * cp * cy - cr * sp * sy
    y = cr * sp * cy + sr * cp * sy
    z = cr * cp * sy - sr * sp * cy

    # Combine into quaternion [w, x, y, z]
    quat = jnp.concatenate([w, x, y, z], axis=-1)

    # Normalize the quaternion
    quat = quat / jnp.linalg.norm(quat, axis=-1, keepdims=True)

    return quat
|
xax-0.0.7/xax/utils/jax.py
ADDED
@@ -0,0 +1,140 @@
|
|
1
|
+
"""Defines some utility functions for interfacing with Jax."""
|
2
|
+
|
3
|
+
import inspect
|
4
|
+
import logging
|
5
|
+
import os
|
6
|
+
import time
|
7
|
+
from functools import wraps
|
8
|
+
from typing import Any, Callable, Iterable, ParamSpec, Sequence, TypeVar, cast
|
9
|
+
|
10
|
+
import jax
|
11
|
+
import jax.numpy as jnp
|
12
|
+
import numpy as np
|
13
|
+
from jax._src import sharding_impls
|
14
|
+
from jax._src.lib import xla_client as xc
|
15
|
+
|
16
|
+
logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
DEFAULT_COMPILE_TIMEOUT = 1.0
|
19
|
+
|
20
|
+
Number = int | float | np.ndarray | jnp.ndarray
|
21
|
+
|
22
|
+
|
23
|
+
P = ParamSpec("P") # For function parameters
|
24
|
+
R = TypeVar("R") # For function return type
|
25
|
+
|
26
|
+
|
27
|
+
def as_float(value: int | float | np.ndarray | jnp.ndarray) -> float:
|
28
|
+
if isinstance(value, (int, float)):
|
29
|
+
return float(value)
|
30
|
+
if isinstance(value, (np.ndarray, jnp.ndarray)):
|
31
|
+
return float(value.item())
|
32
|
+
raise TypeError(f"Unexpected type: {type(value)}")
|
33
|
+
|
34
|
+
|
35
|
+
def get_hash(obj: object) -> int:
    """Get a hash of an object.

    If the object is hashable, use the hash. Otherwise, use the id.

    Args:
        obj: The object to hash.

    Returns:
        ``hash(obj)`` when the object is hashable, otherwise ``id(obj)``.
    """
    # Every object has a `__hash__` attribute (it is `None` on unhashable types
    # such as `list` or `dict`), so an attribute check cannot detect
    # hashability; attempt the hash and fall back on failure instead.
    try:
        return hash(obj)
    except TypeError:
        return id(obj)
|
43
|
+
|
44
|
+
|
45
|
+
def jit(
    in_shardings: Any = sharding_impls.UNSPECIFIED,  # noqa: ANN401
    out_shardings: Any = sharding_impls.UNSPECIFIED,  # noqa: ANN401
    static_argnums: int | Sequence[int] | None = None,
    static_argnames: str | Iterable[str] | None = None,
    donate_argnums: int | Sequence[int] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: xc.Device | None = None,
    backend: str | None = None,
    inline: bool = False,
    abstracted_axes: Any | None = None,  # noqa: ANN401
    compiler_options: dict[str, Any] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Wrapper function that provides utility improvements over Jax's JIT.

    This is meant to be used as a decorator factory. Every argument is
    forwarded unchanged to ``jax.jit``; the decorated function calls
    ``wrapped``, whose behavior is controlled by two environment variables
    that are read on every call:

    - ``DEBUG=1``: the original, un-jitted function is called directly.
    - ``JIT_PROFILE=1``: the function name and call count are logged before
      the call, the wall-clock time of the call is logged afterwards, and any
      argument whose hash differs from the previous profiled call is logged.
      Changed hashes are a hint that the call may have triggered a
      recompilation.

    Returns:
        A decorator that wraps a function with ``jax.jit``.
    """

    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
        class JitState:
            # Number of profiled calls so far; only incremented when
            # JIT_PROFILE=1.
            compilation_count = 0
            # Argument-name -> hash mapping from the previous profiled call.
            last_arg_dict: dict[str, int] | None = None

        # Positional arguments are reported under their parameter names.
        param_names = list(inspect.signature(fn).parameters)

        jitted_fn = jax.jit(
            fn,
            in_shardings=in_shardings,
            out_shardings=out_shardings,
            static_argnums=static_argnums,
            static_argnames=static_argnames,
            donate_argnums=donate_argnums,
            donate_argnames=donate_argnames,
            keep_unused=keep_unused,
            device=device,
            backend=backend,
            inline=inline,
            abstracted_axes=abstracted_axes,
            compiler_options=compiler_options,
        )

        @wraps(fn)
        def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
            if os.environ.get("DEBUG", "0") == "1":  # skipping during debug
                return fn(*args, **kwargs)

            if os.environ.get("JIT_PROFILE", "0") != "1":
                return cast(R, jitted_fn(*args, **kwargs))

            # Prefix `__call__` with the class of its first positional argument
            # so that callable objects can be told apart in the logs.
            class_name = f"{args[0].__class__.__name__}." if args and fn.__name__ == "__call__" else ""
            logger.info(
                "Currently running %s (count: %s)",
                f"{class_name}{fn.__name__}",
                JitState.compilation_count,
            )

            start_time = time.perf_counter()
            res = jitted_fn(*args, **kwargs)
            runtime = time.perf_counter() - start_time

            # Hash every argument under its name. Positional arguments beyond
            # the named parameters (e.g. *args) are not tracked.
            # TODO: we should probably reimplement the lower-level jitting logic to avoid this
            arg_dict: dict[str, int] = {name: get_hash(arg) for name, arg in zip(param_names, args)}
            arg_dict.update((k, get_hash(v)) for k, v in kwargs.items())

            # `runtime` measures the jitted call itself, not the hashing above.
            logger.info("Call took %s seconds", runtime)
            JitState.compilation_count += 1

            if JitState.last_arg_dict is not None:
                # Sorted so that the log order is deterministic between runs.
                for k in sorted(arg_dict.keys() | JitState.last_arg_dict.keys()):
                    prev = JitState.last_arg_dict.get(k, "N/A")
                    curr = arg_dict.get(k, "N/A")

                    if prev != curr:
                        logger.info("- Arg '%s' hash changed: %s -> %s", k, prev, curr)

            JitState.last_arg_dict = arg_dict

            return cast(R, res)

        return wrapped

    return decorator
|
xax-0.0.7/xax/utils/profile.py
ADDED
@@ -0,0 +1,61 @@
|
|
1
|
+
"""Profiling utilities."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import time
|
6
|
+
from functools import wraps
|
7
|
+
from typing import Callable, ParamSpec, TypeVar
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)
|
10
|
+
|
11
|
+
P = ParamSpec("P") # For function parameters
|
12
|
+
R = TypeVar("R") # For function return type
|
13
|
+
|
14
|
+
|
15
|
+
def profile(fn: Callable[P, R]) -> Callable[P, R]:
    """Profiling decorator that tracks function call count and execution time.

    Activated when the PROFILE environment variable is set to "1"; the variable
    is read on every call. When it is not set, the wrapped function is called
    directly and nothing is recorded.

    Args:
        fn: The function to profile.

    Returns:
        A decorated function with profiling capabilities.
    """

    class ProfileState:
        # Totals across all profiled calls of this particular function.
        call_count = 0
        total_time = 0.0

    @wraps(fn)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
        if os.environ.get("PROFILE", "0") != "1":
            return fn(*args, **kwargs)

        start_time = time.perf_counter()
        res = fn(*args, **kwargs)
        runtime = time.perf_counter() - start_time

        ProfileState.call_count += 1
        ProfileState.total_time += runtime

        # Handle class methods by showing the class name. Every object has a
        # `__class__`, so that cannot be used to recognize a method call;
        # instead, treat the call as a method call only when the class of the
        # first positional argument has an attribute named like the function.
        # This is a heuristic, not a guarantee.
        class_name = ""
        if args:
            first = args[0]
            owner = first if isinstance(first, type) else type(first)
            if hasattr(owner, fn.__name__):
                class_name = f"{owner.__name__}."

        logger.info(
            "%s%s - call #%s, took %s seconds, total: %s seconds",
            class_name,
            fn.__name__,
            ProfileState.call_count,
            runtime,
            ProfileState.total_time,
        )

        return res

    return wrapped
|
xax-0.0.7/xax/utils/pytree.py
ADDED
@@ -0,0 +1,50 @@
|
|
1
|
+
"""Utils for accessing, modifying, and otherwise manipulating pytrees."""
|
2
|
+
|
3
|
+
import chex
|
4
|
+
import jax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
from jax import Array
|
7
|
+
from jaxtyping import PyTree
|
8
|
+
|
9
|
+
|
10
|
+
def slice_array(x: Array, start: Array, slice_length: int) -> Array:
    """Get a slice of an array along the first dimension.

    For multi-dimensional arrays, this slices only along the first dimension
    and keeps all other dimensions intact.

    Args:
        x: The array to slice.
        start: Scalar index of the first element of the slice.
        slice_length: Number of elements to take along the first dimension.

    Returns:
        An array of shape ``(slice_length, *x.shape[1:])``.
    """
    chex.assert_shape(start, ())
    chex.assert_shape(slice_length, ())

    # Start at `start` on the first axis and at 0 on every other axis, taking
    # the full extent of the trailing axes.
    start_indices = (start,) + (0,) * (x.ndim - 1)
    slice_sizes = (slice_length,) + x.shape[1:]

    return jax.lax.dynamic_slice(x, start_indices, slice_sizes)
|
22
|
+
|
23
|
+
|
24
|
+
def slice_pytree(pytree: PyTree, start: Array, slice_length: int) -> PyTree:
    """Get a slice of a pytree.

    Args:
        pytree: The pytree whose leaves should be sliced.
        start: Scalar index of the first element of the slice.
        slice_length: Number of elements to take along the first dimension.

    Returns:
        A pytree with the same structure, where ``slice_array`` has been
        applied to every leaf.
    """
    return jax.tree_util.tree_map(lambda x: slice_array(x, start, slice_length), pytree)
|
27
|
+
|
28
|
+
|
29
|
+
def flatten_array(x: Array, flatten_size: int) -> Array:
    """Flatten an array into a (flatten_size, ...) array.

    Args:
        x: The array to flatten.
        flatten_size: Size of the new leading dimension.

    Returns:
        ``x`` reshaped to ``(flatten_size, *x.shape[2:])``.
    """
    # `jnp.reshape` raises if the element count does not match the requested
    # shape, so the leading dimension of the result is always `flatten_size`.
    return jnp.reshape(x, (flatten_size, *x.shape[2:]))
|
34
|
+
|
35
|
+
|
36
|
+
def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
    """Flatten a pytree into a (flatten_size, ...) pytree.

    Args:
        pytree: The pytree whose leaves should be flattened.
        flatten_size: Size of the new leading dimension of every leaf.

    Returns:
        A pytree with the same structure, where ``flatten_array`` has been
        applied to every leaf.
    """
    return jax.tree_util.tree_map(lambda x: flatten_array(x, flatten_size), pytree)
|
39
|
+
|
40
|
+
|
41
|
+
def compute_nan_ratio(pytree: PyTree) -> Array:
    """Computes the fraction of elements in a given PyTree that are NaN.

    Args:
        pytree: The pytree to inspect.

    Returns:
        A scalar array holding the number of NaN elements divided by the total
        number of elements across all leaves, or 0.0 if there are no elements.
    """
    leaves = jax.tree_util.tree_leaves(pytree)

    # `jnp.size` also accepts plain Python scalars, which have no `.size`.
    total_elements = sum(jnp.size(leaf) for leaf in leaves)

    # An empty pytree (or one with only zero-size leaves) has no elements;
    # report a ratio of zero instead of dividing by zero.
    if total_elements == 0:
        return jnp.array(0.0)

    total_nans = sum(jnp.sum(jnp.isnan(leaf)) for leaf in leaves)

    return jnp.array(total_nans / total_elements)
|
@@ -19,6 +19,7 @@ xax/core/state.py
|
|
19
19
|
xax/nn/__init__.py
|
20
20
|
xax/nn/embeddings.py
|
21
21
|
xax/nn/functions.py
|
22
|
+
xax/nn/geom.py
|
22
23
|
xax/nn/parallel.py
|
23
24
|
xax/task/__init__.py
|
24
25
|
xax/task/base.py
|
@@ -52,6 +53,8 @@ xax/utils/experiments.py
|
|
52
53
|
xax/utils/jax.py
|
53
54
|
xax/utils/logging.py
|
54
55
|
xax/utils/numpy.py
|
56
|
+
xax/utils/profile.py
|
57
|
+
xax/utils/pytree.py
|
55
58
|
xax/utils/tensorboard.py
|
56
59
|
xax/utils/text.py
|
57
60
|
xax/utils/data/__init__.py
|
xax-0.0.6/xax/utils/jax.py
DELETED
@@ -1,14 +0,0 @@
|
|
1
|
-
"""Defines some utility functions for interfacing with Jax."""
|
2
|
-
|
3
|
-
import jax.numpy as jnp
|
4
|
-
import numpy as np
|
5
|
-
|
6
|
-
Number = int | float | np.ndarray | jnp.ndarray
|
7
|
-
|
8
|
-
|
9
|
-
def as_float(value: int | float | np.ndarray | jnp.ndarray) -> float:
|
10
|
-
if isinstance(value, (int, float)):
|
11
|
-
return float(value)
|
12
|
-
if isinstance(value, (np.ndarray, jnp.ndarray)):
|
13
|
-
return float(value.item())
|
14
|
-
raise TypeError(f"Unexpected type: {type(value)}")
|
{xax-0.0.6 → xax-0.0.7}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{xax-0.0.6 → xax-0.0.7}/setup.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|