xax 0.2.21__tar.gz → 0.2.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.2.21/xax.egg-info → xax-0.2.22}/PKG-INFO +1 -1
- {xax-0.2.21 → xax-0.2.22}/xax/__init__.py +6 -2
- {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/train.py +7 -3
- {xax-0.2.21 → xax-0.2.22}/xax/utils/jax.py +109 -7
- {xax-0.2.21 → xax-0.2.22/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.2.21 → xax-0.2.22}/LICENSE +0 -0
- {xax-0.2.21 → xax-0.2.22}/MANIFEST.in +0 -0
- {xax-0.2.21 → xax-0.2.22}/README.md +0 -0
- {xax-0.2.21 → xax-0.2.22}/pyproject.toml +0 -0
- {xax-0.2.21 → xax-0.2.22}/setup.cfg +0 -0
- {xax-0.2.21 → xax-0.2.22}/setup.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/cli/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/cli/edit_config.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/core/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/core/conf.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/core/state.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/nn/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/nn/embeddings.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/nn/functions.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/nn/geom.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/nn/losses.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/nn/metrics.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/nn/parallel.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/nn/ssm.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/py.typed +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/requirements-dev.txt +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/requirements.txt +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/base.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/launchers/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/launchers/base.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/launchers/cli.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/launchers/single_process.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/logger.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/loggers/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/loggers/callback.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/loggers/json.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/loggers/state.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/loggers/stdout.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/compile.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/logger.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/process.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/runnable.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/script.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/task/task.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/data/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/data/collate.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/debugging.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/experiments.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/jaxpr.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/logging.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/numpy.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/profile.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/pytree.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/tensorboard.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/text.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/types/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax.egg-info/requires.txt +0 -0
- {xax-0.2.21 → xax-0.2.22}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.2.
|
15
|
+
__version__ = "0.2.22"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -112,8 +112,10 @@ __all__ = [
|
|
112
112
|
"save_config",
|
113
113
|
"stage_environment",
|
114
114
|
"to_markdown_table",
|
115
|
+
"grad",
|
115
116
|
"jit",
|
116
117
|
"scan",
|
118
|
+
"vmap",
|
117
119
|
"save_jaxpr_dot",
|
118
120
|
"ColoredFormatter",
|
119
121
|
"configure_logging",
|
@@ -287,8 +289,10 @@ NAME_MAP: dict[str, str] = {
|
|
287
289
|
"save_config": "utils.experiments",
|
288
290
|
"stage_environment": "utils.experiments",
|
289
291
|
"to_markdown_table": "utils.experiments",
|
292
|
+
"grad": "utils.jax",
|
290
293
|
"jit": "utils.jax",
|
291
294
|
"scan": "utils.jax",
|
295
|
+
"vmap": "utils.jax",
|
292
296
|
"save_jaxpr_dot": "utils.jaxpr",
|
293
297
|
"ColoredFormatter": "utils.logging",
|
294
298
|
"configure_logging": "utils.logging",
|
@@ -460,7 +464,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
460
464
|
stage_environment,
|
461
465
|
to_markdown_table,
|
462
466
|
)
|
463
|
-
from xax.utils.jax import jit, scan
|
467
|
+
from xax.utils.jax import grad, jit, scan, vmap
|
464
468
|
from xax.utils.jaxpr import save_jaxpr_dot
|
465
469
|
from xax.utils.logging import (
|
466
470
|
LOG_ERROR_SUMMARY,
|
@@ -625,9 +625,13 @@ class TrainMixin(
|
|
625
625
|
grad_metrics = {"grad_norm": grad_norm}
|
626
626
|
|
627
627
|
def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
|
628
|
-
# Clip
|
629
|
-
|
630
|
-
|
628
|
+
# Clip gradients based on global norm, similar to optax.clip_by_global_norm
|
629
|
+
trigger = jnp.squeeze(grad_norm < self.config.global_grad_clip)
|
630
|
+
|
631
|
+
def clip_fn(t: Array) -> Array:
|
632
|
+
return jax.lax.select(trigger, t, (t / grad_norm.astype(t.dtype)) * self.config.global_grad_clip)
|
633
|
+
|
634
|
+
grads = jax.tree.map(clip_fn, grads)
|
631
635
|
|
632
636
|
# Apply the gradient updates.
|
633
637
|
updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
|
@@ -6,13 +6,14 @@ import logging
|
|
6
6
|
import os
|
7
7
|
import time
|
8
8
|
from functools import wraps
|
9
|
-
from typing import Any, Callable, Iterable, ParamSpec, Sequence, TypeVar, cast
|
9
|
+
from typing import Any, Callable, Hashable, Iterable, ParamSpec, Sequence, TypeVar, cast
|
10
10
|
|
11
11
|
import jax
|
12
12
|
import jax.numpy as jnp
|
13
13
|
import numpy as np
|
14
14
|
from jax._src import sharding_impls
|
15
15
|
from jax._src.lib import xla_client as xc
|
16
|
+
from jaxtyping import PyTree
|
16
17
|
|
17
18
|
logger = logging.getLogger(__name__)
|
18
19
|
|
@@ -20,6 +21,7 @@ DEFAULT_COMPILE_TIMEOUT = 1.0
|
|
20
21
|
|
21
22
|
Number = int | float | np.ndarray | jnp.ndarray
|
22
23
|
|
24
|
+
T = TypeVar("T", bound=PyTree)
|
23
25
|
|
24
26
|
P = ParamSpec("P") # For function parameters
|
25
27
|
R = TypeVar("R") # For function return type
|
@@ -29,6 +31,9 @@ Carry = TypeVar("Carry")
|
|
29
31
|
X = TypeVar("X")
|
30
32
|
Y = TypeVar("Y")
|
31
33
|
|
34
|
+
F = TypeVar("F", bound=Callable)
|
35
|
+
AxisName = Hashable
|
36
|
+
|
32
37
|
|
33
38
|
@functools.lru_cache(maxsize=None)
|
34
39
|
def disable_jit_level() -> int:
|
@@ -166,6 +171,22 @@ def jit(
|
|
166
171
|
return decorator
|
167
172
|
|
168
173
|
|
174
|
+
def _split_module(tree: T, axis: int = 0) -> list[T]:
|
175
|
+
"""Splits a module in the same way that jax.lax.scan and jax.vmap do.
|
176
|
+
|
177
|
+
Args:
|
178
|
+
tree: The tree to split.
|
179
|
+
axis: The axis to split on.
|
180
|
+
|
181
|
+
Returns:
|
182
|
+
A list of the split trees.
|
183
|
+
"""
|
184
|
+
first_leaf = jax.tree.leaves(tree)[0]
|
185
|
+
num_slices = first_leaf.shape[axis]
|
186
|
+
result = [jax.tree.map(lambda x, idx=i: jnp.take(x, idx, axis=axis), tree) for i in range(num_slices)]
|
187
|
+
return result
|
188
|
+
|
189
|
+
|
169
190
|
def scan(
|
170
191
|
f: Callable[[Carry, X], tuple[Carry, Y]],
|
171
192
|
init: Carry,
|
@@ -195,15 +216,96 @@ def scan(
|
|
195
216
|
if not should_disable_jit(jit_level):
|
196
217
|
return jax.lax.scan(f, init, xs, length, reverse, unroll)
|
197
218
|
|
219
|
+
carry = init
|
220
|
+
ys = []
|
221
|
+
|
198
222
|
if xs is None:
|
199
223
|
if length is None:
|
200
224
|
raise ValueError("length must be provided if xs is None")
|
201
|
-
|
225
|
+
for _ in range(length) if not reverse else range(length - 1, -1, -1):
|
226
|
+
carry, y = f(carry, None) # type: ignore[arg-type]
|
227
|
+
ys.append(y)
|
202
228
|
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
229
|
+
else:
|
230
|
+
xlist = _split_module(xs, axis=0)
|
231
|
+
if reverse:
|
232
|
+
xlist = xlist[::-1]
|
233
|
+
for x in xlist:
|
234
|
+
carry, y = f(carry, x)
|
235
|
+
ys.append(y)
|
236
|
+
|
237
|
+
if reverse:
|
238
|
+
ys = ys[::-1]
|
239
|
+
|
240
|
+
if not ys:
|
241
|
+
return carry, jnp.array([]) # type: ignore[return-value]
|
208
242
|
|
209
243
|
return carry, jax.tree.map(lambda *ys: jnp.stack(ys), *ys)
|
244
|
+
|
245
|
+
|
246
|
+
def vmap(
    fun: Callable[P, R],
    in_axes: int | Sequence[int | None] = 0,
    jit_level: int | None = None,
) -> Callable[P, R]:
    """A wrapper around jax.vmap that allows for more flexible tracing.

    When ``should_disable_jit(jit_level)`` is false, this simply returns
    ``jax.vmap(fun, in_axes=in_axes)``. Otherwise the mapped axis is unrolled
    as a Python for loop:

    - every mapped argument is split with ``_split_module``;
    - ``fun`` is called once per slice;
    - the outputs are stacked along a new leading axis (the equivalent of
      ``out_axes=0``).

    Args:
        fun: The function to vectorize.
        in_axes: A single axis used for every positional argument, or one
            entry per positional argument. ``None`` entries are passed
            unchanged to every call instead of being split.
        jit_level: The JIT level of this call, passed to
            ``should_disable_jit``.

    Returns:
        The vectorized function. In the unrolled path it accepts positional
        arguments only.

    Raises:
        ValueError: Raised by the unrolled function if keyword arguments are
            passed, if ``in_axes`` does not match the positional arguments,
            if no argument is mapped, or if the mapped arguments disagree on
            the size of their mapped axis.
    """
    if not should_disable_jit(jit_level):
        return jax.vmap(fun, in_axes=in_axes)

    @functools.wraps(fun)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
        if kwargs:
            raise ValueError("vmap does not support keyword arguments")

        ia = in_axes
        if isinstance(ia, int):
            ia = [ia] * len(args)
        elif len(ia) != len(args):
            raise ValueError("in_axes must be the same length as args")

        if not all(isinstance(a, int) or a is None for a in ia):
            raise ValueError("in_axes must be a list of integers or None")

        # Split each mapped argument exactly once. None marks a broadcast
        # (unmapped) argument.
        splits = [None if axis is None else _split_module(arg, axis=axis) for axis, arg in zip(ia, args, strict=True)]
        batch_sizes = {len(s) for s in splits if s is not None}
        if not batch_sizes:
            # jax.vmap rejects this case as well, so the unrolled path must not
            # silently return an unbatched result.
            raise ValueError("vmap must have at least one non-None value in in_axes")
        if len(batch_sizes) > 1:
            raise ValueError(f"vmap got inconsistent sizes for the mapped axes: {sorted(batch_sizes)}")
        (ns,) = batch_sizes

        split_args = [[arg] * ns if s is None else s for arg, s in zip(args, splits, strict=True)]
        split_outputs = [fun(*sargs) for sargs in zip(*split_args, strict=True)]

        if not split_outputs:
            return jnp.array([])  # type: ignore[return-value]

        return jax.tree.map(lambda *ys: jnp.stack(ys), *split_outputs)

    return wrapped
|
285
|
+
|
286
|
+
|
287
|
+
def grad(
    fun: Callable[P, R],
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence[AxisName] = (),
    jit_level: int | None = None,
) -> Callable:
    """A wrapper around jax.grad that allows for more flexible tracing.

    When ``should_disable_jit(jit_level)`` is false, this is just ``jax.grad``
    with the same arguments. Otherwise the returned function does two things:

    - it first calls ``fun`` directly with the given arguments, outside of
      ``jax.grad`` (presumably to aid debugging);
    - it then returns the gradient computed by ``jax.grad``.

    As a result, ``fun`` runs twice per call in that mode.

    Args:
        fun: The function to differentiate.
        argnums: Forwarded to ``jax.grad``.
        has_aux: Forwarded to ``jax.grad``.
        holomorphic: Forwarded to ``jax.grad``.
        allow_int: Forwarded to ``jax.grad``.
        reduce_axes: Forwarded to ``jax.grad`` only when non-empty.
        jit_level: The JIT level of this call, passed to
            ``should_disable_jit``.

    Returns:
        A function that evaluates the gradient of ``fun`` (as produced by
        ``jax.grad`` for the given options).
    """
    grad_kwargs: dict[str, Any] = {
        "argnums": argnums,
        "has_aux": has_aux,
        "holomorphic": holomorphic,
        "allow_int": allow_int,
    }
    # Pass options by keyword, and only forward reduce_axes when it is
    # actually used. Recent JAX releases deprecate that parameter, so passing
    # it unconditionally (and positionally) is fragile across versions.
    if reduce_axes:
        grad_kwargs["reduce_axes"] = reduce_axes

    # Build the gradient function once instead of on every call.
    grad_fun = jax.grad(fun, **grad_kwargs)

    if not should_disable_jit(jit_level):
        return grad_fun

    @functools.wraps(fun)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> Any:
        # Run the function directly first; the gradient below is computed by a
        # second, separate call of `fun` inside jax.grad.
        fun(*args, **kwargs)

        return grad_fun(*args, **kwargs)

    return wrapped
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|