xax 0.2.21__tar.gz → 0.2.22__tar.gz

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (73)
  1. {xax-0.2.21/xax.egg-info → xax-0.2.22}/PKG-INFO +1 -1
  2. {xax-0.2.21 → xax-0.2.22}/xax/__init__.py +6 -2
  3. {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/train.py +7 -3
  4. {xax-0.2.21 → xax-0.2.22}/xax/utils/jax.py +109 -7
  5. {xax-0.2.21 → xax-0.2.22/xax.egg-info}/PKG-INFO +1 -1
  6. {xax-0.2.21 → xax-0.2.22}/LICENSE +0 -0
  7. {xax-0.2.21 → xax-0.2.22}/MANIFEST.in +0 -0
  8. {xax-0.2.21 → xax-0.2.22}/README.md +0 -0
  9. {xax-0.2.21 → xax-0.2.22}/pyproject.toml +0 -0
  10. {xax-0.2.21 → xax-0.2.22}/setup.cfg +0 -0
  11. {xax-0.2.21 → xax-0.2.22}/setup.py +0 -0
  12. {xax-0.2.21 → xax-0.2.22}/xax/cli/__init__.py +0 -0
  13. {xax-0.2.21 → xax-0.2.22}/xax/cli/edit_config.py +0 -0
  14. {xax-0.2.21 → xax-0.2.22}/xax/core/__init__.py +0 -0
  15. {xax-0.2.21 → xax-0.2.22}/xax/core/conf.py +0 -0
  16. {xax-0.2.21 → xax-0.2.22}/xax/core/state.py +0 -0
  17. {xax-0.2.21 → xax-0.2.22}/xax/nn/__init__.py +0 -0
  18. {xax-0.2.21 → xax-0.2.22}/xax/nn/embeddings.py +0 -0
  19. {xax-0.2.21 → xax-0.2.22}/xax/nn/functions.py +0 -0
  20. {xax-0.2.21 → xax-0.2.22}/xax/nn/geom.py +0 -0
  21. {xax-0.2.21 → xax-0.2.22}/xax/nn/losses.py +0 -0
  22. {xax-0.2.21 → xax-0.2.22}/xax/nn/metrics.py +0 -0
  23. {xax-0.2.21 → xax-0.2.22}/xax/nn/parallel.py +0 -0
  24. {xax-0.2.21 → xax-0.2.22}/xax/nn/ssm.py +0 -0
  25. {xax-0.2.21 → xax-0.2.22}/xax/py.typed +0 -0
  26. {xax-0.2.21 → xax-0.2.22}/xax/requirements-dev.txt +0 -0
  27. {xax-0.2.21 → xax-0.2.22}/xax/requirements.txt +0 -0
  28. {xax-0.2.21 → xax-0.2.22}/xax/task/__init__.py +0 -0
  29. {xax-0.2.21 → xax-0.2.22}/xax/task/base.py +0 -0
  30. {xax-0.2.21 → xax-0.2.22}/xax/task/launchers/__init__.py +0 -0
  31. {xax-0.2.21 → xax-0.2.22}/xax/task/launchers/base.py +0 -0
  32. {xax-0.2.21 → xax-0.2.22}/xax/task/launchers/cli.py +0 -0
  33. {xax-0.2.21 → xax-0.2.22}/xax/task/launchers/single_process.py +0 -0
  34. {xax-0.2.21 → xax-0.2.22}/xax/task/logger.py +0 -0
  35. {xax-0.2.21 → xax-0.2.22}/xax/task/loggers/__init__.py +0 -0
  36. {xax-0.2.21 → xax-0.2.22}/xax/task/loggers/callback.py +0 -0
  37. {xax-0.2.21 → xax-0.2.22}/xax/task/loggers/json.py +0 -0
  38. {xax-0.2.21 → xax-0.2.22}/xax/task/loggers/state.py +0 -0
  39. {xax-0.2.21 → xax-0.2.22}/xax/task/loggers/stdout.py +0 -0
  40. {xax-0.2.21 → xax-0.2.22}/xax/task/loggers/tensorboard.py +0 -0
  41. {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/__init__.py +0 -0
  42. {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/artifacts.py +0 -0
  43. {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/checkpointing.py +0 -0
  44. {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/compile.py +0 -0
  45. {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/cpu_stats.py +0 -0
  46. {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/data_loader.py +0 -0
  47. {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/gpu_stats.py +0 -0
  48. {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/logger.py +0 -0
  49. {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/process.py +0 -0
  50. {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/runnable.py +0 -0
  51. {xax-0.2.21 → xax-0.2.22}/xax/task/mixins/step_wrapper.py +0 -0
  52. {xax-0.2.21 → xax-0.2.22}/xax/task/script.py +0 -0
  53. {xax-0.2.21 → xax-0.2.22}/xax/task/task.py +0 -0
  54. {xax-0.2.21 → xax-0.2.22}/xax/utils/__init__.py +0 -0
  55. {xax-0.2.21 → xax-0.2.22}/xax/utils/data/__init__.py +0 -0
  56. {xax-0.2.21 → xax-0.2.22}/xax/utils/data/collate.py +0 -0
  57. {xax-0.2.21 → xax-0.2.22}/xax/utils/debugging.py +0 -0
  58. {xax-0.2.21 → xax-0.2.22}/xax/utils/experiments.py +0 -0
  59. {xax-0.2.21 → xax-0.2.22}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.2.21 → xax-0.2.22}/xax/utils/logging.py +0 -0
  61. {xax-0.2.21 → xax-0.2.22}/xax/utils/numpy.py +0 -0
  62. {xax-0.2.21 → xax-0.2.22}/xax/utils/profile.py +0 -0
  63. {xax-0.2.21 → xax-0.2.22}/xax/utils/pytree.py +0 -0
  64. {xax-0.2.21 → xax-0.2.22}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.2.21 → xax-0.2.22}/xax/utils/text.py +0 -0
  66. {xax-0.2.21 → xax-0.2.22}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.2.21 → xax-0.2.22}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.2.21 → xax-0.2.22}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.21 → xax-0.2.22}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.21 → xax-0.2.22}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.21 → xax-0.2.22}/xax.egg-info/entry_points.txt +0 -0
  72. {xax-0.2.21 → xax-0.2.22}/xax.egg-info/requires.txt +0 -0
  73. {xax-0.2.21 → xax-0.2.22}/xax.egg-info/top_level.txt +0 -0
{xax-0.2.21/xax.egg-info → xax-0.2.22}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.2.21
+Version: 0.2.22
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
{xax-0.2.21 → xax-0.2.22}/xax/__init__.py
@@ -12,7 +12,7 @@ and running the update script:
 python -m scripts.update_api --inplace
 """
 
-__version__ = "0.2.21"
+__version__ = "0.2.22"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -112,8 +112,10 @@ __all__ = [
     "save_config",
     "stage_environment",
     "to_markdown_table",
+    "grad",
     "jit",
     "scan",
+    "vmap",
     "save_jaxpr_dot",
     "ColoredFormatter",
     "configure_logging",
@@ -287,8 +289,10 @@ NAME_MAP: dict[str, str] = {
     "save_config": "utils.experiments",
     "stage_environment": "utils.experiments",
     "to_markdown_table": "utils.experiments",
+    "grad": "utils.jax",
     "jit": "utils.jax",
     "scan": "utils.jax",
+    "vmap": "utils.jax",
     "save_jaxpr_dot": "utils.jaxpr",
     "ColoredFormatter": "utils.logging",
     "configure_logging": "utils.logging",
@@ -460,7 +464,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         stage_environment,
         to_markdown_table,
     )
-    from xax.utils.jax import jit, scan
+    from xax.utils.jax import grad, jit, scan, vmap
     from xax.utils.jaxpr import save_jaxpr_dot
     from xax.utils.logging import (
         LOG_ERROR_SUMMARY,
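
For context, grad and vmap now join jit and scan as top-level re-exports of the wrappers defined in xax.utils.jax. A minimal usage sketch (the loss function and data below are illustrative, not from the package):

    import jax.numpy as jnp
    import xax

    def loss(w, x):
        return jnp.sum((w * x) ** 2)

    # Same call signatures as jax.grad / jax.vmap, re-exported at the top level.
    grad_fn = xax.grad(loss)                     # gradient w.r.t. the first argument
    batched = xax.vmap(loss, in_axes=(None, 0))  # map over a batch of x

    w = jnp.ones(3)
    xs = jnp.arange(6.0).reshape(2, 3)
    print(grad_fn(w, xs[0]).shape)  # (3,)
    print(batched(w, xs).shape)     # (2,)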
{xax-0.2.21 → xax-0.2.22}/xax/task/mixins/train.py
@@ -625,9 +625,13 @@ class TrainMixin(
         grad_metrics = {"grad_norm": grad_norm}
 
         def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
-            # Clip the global gradient norm to some desired range.
-            grad_factor = self.config.global_grad_clip / jnp.maximum(grad_norm, 1e-6)
-            grads = jax.tree.map(lambda x: x * grad_factor, grads)
+            # Clip gradients based on global norm, similar to optax.clip_by_global_norm
+            trigger = jnp.squeeze(grad_norm < self.config.global_grad_clip)
+
+            def clip_fn(t: Array) -> Array:
+                return jax.lax.select(trigger, t, (t / grad_norm.astype(t.dtype)) * self.config.global_grad_clip)
+
+            grads = jax.tree.map(clip_fn, grads)
 
             # Apply the gradient updates.
             updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
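
The old code rescaled every gradient unconditionally, so gradients whose global norm was already below global_grad_clip were scaled *up*; the new code only rescales when the norm exceeds the threshold, matching the semantics of optax.clip_by_global_norm. A standalone sketch of the same logic (the function name and arguments here are illustrative):

    import jax
    import jax.numpy as jnp
    import optax

    def clip_by_global_norm(grads, max_norm: float):
        """Rescale `grads` only when their global L2 norm exceeds `max_norm`."""
        global_norm = optax.global_norm(grads)
        trigger = global_norm < max_norm

        def clip_fn(t):
            # Leave the leaf untouched below the threshold; otherwise scale it
            # down so the overall global norm equals `max_norm`.
            return jax.lax.select(trigger, t, (t / global_norm.astype(t.dtype)) * max_norm)

        return jax.tree.map(clip_fn, grads)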
{xax-0.2.21 → xax-0.2.22}/xax/utils/jax.py
@@ -6,13 +6,14 @@ import logging
 import os
 import time
 from functools import wraps
-from typing import Any, Callable, Iterable, ParamSpec, Sequence, TypeVar, cast
+from typing import Any, Callable, Hashable, Iterable, ParamSpec, Sequence, TypeVar, cast
 
 import jax
 import jax.numpy as jnp
 import numpy as np
 from jax._src import sharding_impls
 from jax._src.lib import xla_client as xc
+from jaxtyping import PyTree
 
 logger = logging.getLogger(__name__)
 
@@ -20,6 +21,7 @@ DEFAULT_COMPILE_TIMEOUT = 1.0
 
 Number = int | float | np.ndarray | jnp.ndarray
 
+T = TypeVar("T", bound=PyTree)
 
 P = ParamSpec("P")  # For function parameters
 R = TypeVar("R")  # For function return type
@@ -29,6 +31,9 @@ Carry = TypeVar("Carry")
 X = TypeVar("X")
 Y = TypeVar("Y")
 
+F = TypeVar("F", bound=Callable)
+AxisName = Hashable
+
 
 @functools.lru_cache(maxsize=None)
 def disable_jit_level() -> int:
@@ -166,6 +171,22 @@ def jit(
     return decorator
 
 
+def _split_module(tree: T, axis: int = 0) -> list[T]:
+    """Splits a module in the same way that jax.lax.scan and jax.vmap do.
+
+    Args:
+        tree: The tree to split.
+        axis: The axis to split on.
+
+    Returns:
+        A list of the split trees.
+    """
+    first_leaf = jax.tree.leaves(tree)[0]
+    num_slices = first_leaf.shape[axis]
+    result = [jax.tree.map(lambda x, idx=i: jnp.take(x, idx, axis=axis), tree) for i in range(num_slices)]
+    return result
+
+
 def scan(
     f: Callable[[Carry, X], tuple[Carry, Y]],
     init: Carry,
195
216
  if not should_disable_jit(jit_level):
196
217
  return jax.lax.scan(f, init, xs, length, reverse, unroll)
197
218
 
219
+ carry = init
220
+ ys = []
221
+
198
222
  if xs is None:
199
223
  if length is None:
200
224
  raise ValueError("length must be provided if xs is None")
201
- xs = cast(X, [None] * length)
225
+ for _ in range(length) if not reverse else range(length - 1, -1, -1):
226
+ carry, y = f(carry, None) # type: ignore[arg-type]
227
+ ys.append(y)
202
228
 
203
- carry = init
204
- ys = []
205
- for x in cast(Iterable, xs):
206
- carry, y = f(carry, x)
207
- ys.append(y)
229
+ else:
230
+ xlist = _split_module(xs, axis=0)
231
+ if reverse:
232
+ xlist = xlist[::-1]
233
+ for x in xlist:
234
+ carry, y = f(carry, x)
235
+ ys.append(y)
236
+
237
+ if reverse:
238
+ ys = ys[::-1]
239
+
240
+ if not ys:
241
+ return carry, jnp.array([]) # type: ignore[return-value]
208
242
 
209
243
  return carry, jax.tree.map(lambda *ys: jnp.stack(ys), *ys)
244
+
245
+
246
+ def vmap(
247
+ fun: Callable[P, R],
248
+ in_axes: int | Sequence[int | None] = 0,
249
+ jit_level: int | None = None,
250
+ ) -> Callable[P, R]:
251
+ """A wrapper around jax.lax.vmap that allows for more flexible tracing.
252
+
253
+ If the provided JIT level is below the environment JIT level, we manually
254
+ unroll the scan function as a for loop.
255
+ """
256
+ if not should_disable_jit(jit_level):
257
+ return jax.vmap(fun, in_axes=in_axes)
258
+
259
+ @functools.wraps(fun)
260
+ def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
261
+ if kwargs:
262
+ raise ValueError("vmap does not support keyword arguments")
263
+
264
+ ia = in_axes
265
+ if isinstance(ia, int):
266
+ ia = [ia] * len(args)
267
+ elif len(ia) != len(args):
268
+ raise ValueError("in_axes must be the same length as args")
269
+
270
+ if not all(isinstance(a, int) or a is None for a in ia):
271
+ raise ValueError("in_axes must be a list of integers or None")
272
+
273
+ ns = next((len(_split_module(a, axis=i)) for i, a in zip(ia, args, strict=True) if i is not None), None)
274
+ if ns is None:
275
+ return fun(*args, **kwargs)
276
+ split_args = [[a] * ns if i is None else _split_module(a, axis=i) for i, a in zip(ia, args, strict=True)]
277
+ split_outputs = [fun(*sargs, **kwargs) for sargs in zip(*split_args, strict=True)]
278
+
279
+ if not split_outputs:
280
+ return jnp.array([]) # type: ignore[return-value]
281
+
282
+ return jax.tree.map(lambda *ys: jnp.stack(ys), *split_outputs)
283
+
284
+ return wrapped
285
+
286
+
287
+ def grad(
288
+ fun: Callable[P, R],
289
+ argnums: int | Sequence[int] = 0,
290
+ has_aux: bool = False,
291
+ holomorphic: bool = False,
292
+ allow_int: bool = False,
293
+ reduce_axes: Sequence[AxisName] = (),
294
+ jit_level: int | None = None,
295
+ ) -> Callable:
296
+ """A wrapper around jax.grad that allows for more flexible tracing.
297
+
298
+ We don't do anything special here, we just manually evaluate the function
299
+ if the JIT level is below the environment JIT level.
300
+ """
301
+ if not should_disable_jit(jit_level):
302
+ return jax.grad(fun, argnums, has_aux, holomorphic, allow_int, reduce_axes)
303
+
304
+ @functools.wraps(fun)
305
+ def wrapped(*args: P.args, **kwargs: P.kwargs) -> Callable:
306
+ # Evaluate the function once, then just return the gradient.
307
+ fun(*args, **kwargs)
308
+
309
+ return jax.grad(fun, argnums, has_aux, holomorphic, allow_int, reduce_axes)(*args, **kwargs)
310
+
311
+ return wrapped
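
In the disabled-JIT path, grad first calls the function once outside of tracing, so side effects such as print see concrete values, and then still returns the exact jax.grad result. A usage sketch (the jit_level value is illustrative):

    import jax.numpy as jnp
    import xax

    def loss(w):
        value = jnp.sum(w ** 2)
        # When the eager path is active, this prints a concrete value on the
        # first, un-traced evaluation, rather than only ever showing a tracer.
        print("loss =", value)
        return value

    g = xax.grad(loss, jit_level=3)(jnp.arange(3.0))
    print(g)  # [0. 2. 4.]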
{xax-0.2.21 → xax-0.2.22/xax.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.2.21
+Version: 0.2.22
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte