xax 0.2.6__tar.gz → 0.2.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. {xax-0.2.6/xax.egg-info → xax-0.2.8}/PKG-INFO +1 -1
  2. {xax-0.2.6 → xax-0.2.8}/pyproject.toml +1 -0
  3. {xax-0.2.6 → xax-0.2.8}/xax/__init__.py +52 -2
  4. {xax-0.2.6 → xax-0.2.8}/xax/core/conf.py +1 -1
  5. {xax-0.2.6 → xax-0.2.8}/xax/nn/equinox.py +6 -3
  6. {xax-0.2.6 → xax-0.2.8}/xax/nn/functions.py +8 -5
  7. {xax-0.2.6 → xax-0.2.8}/xax/nn/geom.py +49 -0
  8. {xax-0.2.6 → xax-0.2.8}/xax/task/base.py +2 -2
  9. {xax-0.2.6 → xax-0.2.8}/xax/task/logger.py +11 -6
  10. {xax-0.2.6 → xax-0.2.8}/xax/task/loggers/callback.py +6 -0
  11. {xax-0.2.6 → xax-0.2.8}/xax/task/loggers/json.py +14 -2
  12. {xax-0.2.6 → xax-0.2.8}/xax/task/loggers/state.py +26 -1
  13. {xax-0.2.6 → xax-0.2.8}/xax/task/loggers/stdout.py +4 -2
  14. {xax-0.2.6 → xax-0.2.8}/xax/task/loggers/tensorboard.py +19 -1
  15. {xax-0.2.6 → xax-0.2.8}/xax/task/mixins/artifacts.py +11 -8
  16. {xax-0.2.6 → xax-0.2.8}/xax/task/mixins/checkpointing.py +108 -143
  17. {xax-0.2.6 → xax-0.2.8}/xax/task/mixins/train.py +21 -17
  18. {xax-0.2.6 → xax-0.2.8}/xax/utils/experiments.py +2 -1
  19. {xax-0.2.6 → xax-0.2.8}/xax/utils/jaxpr.py +5 -5
  20. {xax-0.2.6 → xax-0.2.8}/xax/utils/pytree.py +9 -2
  21. {xax-0.2.6 → xax-0.2.8}/xax/utils/text.py +2 -2
  22. {xax-0.2.6 → xax-0.2.8}/xax/utils/types/frozen_dict.py +2 -2
  23. {xax-0.2.6 → xax-0.2.8/xax.egg-info}/PKG-INFO +1 -1
  24. {xax-0.2.6 → xax-0.2.8}/LICENSE +0 -0
  25. {xax-0.2.6 → xax-0.2.8}/MANIFEST.in +0 -0
  26. {xax-0.2.6 → xax-0.2.8}/README.md +0 -0
  27. {xax-0.2.6 → xax-0.2.8}/setup.cfg +0 -0
  28. {xax-0.2.6 → xax-0.2.8}/setup.py +0 -0
  29. {xax-0.2.6 → xax-0.2.8}/xax/core/__init__.py +0 -0
  30. {xax-0.2.6 → xax-0.2.8}/xax/core/state.py +0 -0
  31. {xax-0.2.6 → xax-0.2.8}/xax/nn/__init__.py +0 -0
  32. {xax-0.2.6 → xax-0.2.8}/xax/nn/embeddings.py +0 -0
  33. {xax-0.2.6 → xax-0.2.8}/xax/nn/export.py +0 -0
  34. {xax-0.2.6 → xax-0.2.8}/xax/nn/losses.py +0 -0
  35. {xax-0.2.6 → xax-0.2.8}/xax/nn/norm.py +0 -0
  36. {xax-0.2.6 → xax-0.2.8}/xax/nn/parallel.py +0 -0
  37. {xax-0.2.6 → xax-0.2.8}/xax/nn/ssm.py +0 -0
  38. {xax-0.2.6 → xax-0.2.8}/xax/py.typed +0 -0
  39. {xax-0.2.6 → xax-0.2.8}/xax/requirements-dev.txt +0 -0
  40. {xax-0.2.6 → xax-0.2.8}/xax/requirements.txt +0 -0
  41. {xax-0.2.6 → xax-0.2.8}/xax/task/__init__.py +0 -0
  42. {xax-0.2.6 → xax-0.2.8}/xax/task/launchers/__init__.py +0 -0
  43. {xax-0.2.6 → xax-0.2.8}/xax/task/launchers/base.py +0 -0
  44. {xax-0.2.6 → xax-0.2.8}/xax/task/launchers/cli.py +0 -0
  45. {xax-0.2.6 → xax-0.2.8}/xax/task/launchers/single_process.py +0 -0
  46. {xax-0.2.6 → xax-0.2.8}/xax/task/loggers/__init__.py +0 -0
  47. {xax-0.2.6 → xax-0.2.8}/xax/task/mixins/__init__.py +0 -0
  48. {xax-0.2.6 → xax-0.2.8}/xax/task/mixins/compile.py +0 -0
  49. {xax-0.2.6 → xax-0.2.8}/xax/task/mixins/cpu_stats.py +0 -0
  50. {xax-0.2.6 → xax-0.2.8}/xax/task/mixins/data_loader.py +0 -0
  51. {xax-0.2.6 → xax-0.2.8}/xax/task/mixins/gpu_stats.py +0 -0
  52. {xax-0.2.6 → xax-0.2.8}/xax/task/mixins/logger.py +0 -0
  53. {xax-0.2.6 → xax-0.2.8}/xax/task/mixins/process.py +0 -0
  54. {xax-0.2.6 → xax-0.2.8}/xax/task/mixins/runnable.py +0 -0
  55. {xax-0.2.6 → xax-0.2.8}/xax/task/mixins/step_wrapper.py +0 -0
  56. {xax-0.2.6 → xax-0.2.8}/xax/task/script.py +0 -0
  57. {xax-0.2.6 → xax-0.2.8}/xax/task/task.py +0 -0
  58. {xax-0.2.6 → xax-0.2.8}/xax/utils/__init__.py +0 -0
  59. {xax-0.2.6 → xax-0.2.8}/xax/utils/data/__init__.py +0 -0
  60. {xax-0.2.6 → xax-0.2.8}/xax/utils/data/collate.py +0 -0
  61. {xax-0.2.6 → xax-0.2.8}/xax/utils/debugging.py +0 -0
  62. {xax-0.2.6 → xax-0.2.8}/xax/utils/jax.py +0 -0
  63. {xax-0.2.6 → xax-0.2.8}/xax/utils/logging.py +0 -0
  64. {xax-0.2.6 → xax-0.2.8}/xax/utils/numpy.py +0 -0
  65. {xax-0.2.6 → xax-0.2.8}/xax/utils/profile.py +0 -0
  66. {xax-0.2.6 → xax-0.2.8}/xax/utils/tensorboard.py +0 -0
  67. {xax-0.2.6 → xax-0.2.8}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.2.6 → xax-0.2.8}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.6 → xax-0.2.8}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.6 → xax-0.2.8}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.6 → xax-0.2.8}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.6 → xax-0.2.8}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.6
3
+ Version: 0.2.8
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -57,6 +57,7 @@ target-version = "py311"
57
57
 
58
58
  select = [
59
59
  "ANN",
60
+ "B",
60
61
  "D",
61
62
  "E",
62
63
  "F",
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.6"
15
+ __version__ = "0.2.8"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -43,9 +43,12 @@ __all__ = [
43
43
  "cubic_bezier_interpolation",
44
44
  "euler_to_quat",
45
45
  "get_projected_gravity_vector_from_quat",
46
+ "normalize",
46
47
  "quat_to_euler",
47
48
  "quat_to_rotmat",
48
49
  "rotate_vector_by_quat",
50
+ "rotation6d_to_rotation_matrix",
51
+ "rotation_matrix_to_rotation6d",
49
52
  "cross_entropy",
50
53
  "cast_norm_type",
51
54
  "get_norm",
@@ -57,8 +60,18 @@ __all__ = [
57
60
  "BaseLauncher",
58
61
  "CliLauncher",
59
62
  "SingleProcessLauncher",
63
+ "LogDistribution",
64
+ "LogError",
65
+ "LogErrorSummary",
66
+ "LogGraph",
67
+ "LogHistogram",
60
68
  "LogImage",
61
69
  "LogLine",
70
+ "LogMesh",
71
+ "LogPing",
72
+ "LogScalar",
73
+ "LogStatus",
74
+ "LogVideo",
62
75
  "Logger",
63
76
  "LoggerImpl",
64
77
  "CallbackLogger",
@@ -66,6 +79,7 @@ __all__ = [
66
79
  "StateLogger",
67
80
  "StdoutLogger",
68
81
  "TensorboardLogger",
82
+ "load_ckpt",
69
83
  "CPUStatsOptions",
70
84
  "DataloaderConfig",
71
85
  "GPUStatsOptions",
@@ -115,6 +129,7 @@ __all__ = [
115
129
  "compute_nan_ratio",
116
130
  "flatten_array",
117
131
  "flatten_pytree",
132
+ "get_pytree_param_count",
118
133
  "pytree_has_nans",
119
134
  "reshuffle_pytree",
120
135
  "reshuffle_pytree_along_dims",
@@ -207,9 +222,12 @@ NAME_MAP: dict[str, str] = {
207
222
  "cubic_bezier_interpolation": "nn.geom",
208
223
  "euler_to_quat": "nn.geom",
209
224
  "get_projected_gravity_vector_from_quat": "nn.geom",
225
+ "normalize": "nn.geom",
210
226
  "quat_to_euler": "nn.geom",
211
227
  "quat_to_rotmat": "nn.geom",
212
228
  "rotate_vector_by_quat": "nn.geom",
229
+ "rotation6d_to_rotation_matrix": "nn.geom",
230
+ "rotation_matrix_to_rotation6d": "nn.geom",
213
231
  "cross_entropy": "nn.losses",
214
232
  "cast_norm_type": "nn.norm",
215
233
  "get_norm": "nn.norm",
@@ -221,8 +239,18 @@ NAME_MAP: dict[str, str] = {
221
239
  "BaseLauncher": "task.launchers.base",
222
240
  "CliLauncher": "task.launchers.cli",
223
241
  "SingleProcessLauncher": "task.launchers.single_process",
242
+ "LogDistribution": "task.logger",
243
+ "LogError": "task.logger",
244
+ "LogErrorSummary": "task.logger",
245
+ "LogGraph": "task.logger",
246
+ "LogHistogram": "task.logger",
224
247
  "LogImage": "task.logger",
225
248
  "LogLine": "task.logger",
249
+ "LogMesh": "task.logger",
250
+ "LogPing": "task.logger",
251
+ "LogScalar": "task.logger",
252
+ "LogStatus": "task.logger",
253
+ "LogVideo": "task.logger",
226
254
  "Logger": "task.logger",
227
255
  "LoggerImpl": "task.logger",
228
256
  "CallbackLogger": "task.loggers.callback",
@@ -230,6 +258,7 @@ NAME_MAP: dict[str, str] = {
230
258
  "StateLogger": "task.loggers.state",
231
259
  "StdoutLogger": "task.loggers.stdout",
232
260
  "TensorboardLogger": "task.loggers.tensorboard",
261
+ "load_ckpt": "task.mixins.checkpointing",
233
262
  "CPUStatsOptions": "task.mixins.cpu_stats",
234
263
  "DataloaderConfig": "task.mixins.data_loader",
235
264
  "GPUStatsOptions": "task.mixins.gpu_stats",
@@ -279,6 +308,7 @@ NAME_MAP: dict[str, str] = {
279
308
  "compute_nan_ratio": "utils.pytree",
280
309
  "flatten_array": "utils.pytree",
281
310
  "flatten_pytree": "utils.pytree",
311
+ "get_param_count": "utils.pytree",
282
312
  "pytree_has_nans": "utils.pytree",
283
313
  "reshuffle_pytree": "utils.pytree",
284
314
  "reshuffle_pytree_along_dims": "utils.pytree",
@@ -372,9 +402,12 @@ if IMPORT_ALL or TYPE_CHECKING:
372
402
  cubic_bezier_interpolation,
373
403
  euler_to_quat,
374
404
  get_projected_gravity_vector_from_quat,
405
+ normalize,
375
406
  quat_to_euler,
376
407
  quat_to_rotmat,
377
408
  rotate_vector_by_quat,
409
+ rotation6d_to_rotation_matrix,
410
+ rotation_matrix_to_rotation6d,
378
411
  )
379
412
  from xax.nn.losses import cross_entropy
380
413
  from xax.nn.norm import NormType, cast_norm_type, get_norm
@@ -384,12 +417,28 @@ if IMPORT_ALL or TYPE_CHECKING:
384
417
  from xax.task.launchers.base import BaseLauncher
385
418
  from xax.task.launchers.cli import CliLauncher
386
419
  from xax.task.launchers.single_process import SingleProcessLauncher
387
- from xax.task.logger import Logger, LoggerImpl, LogImage, LogLine
420
+ from xax.task.logger import (
421
+ LogDistribution,
422
+ LogError,
423
+ LogErrorSummary,
424
+ Logger,
425
+ LoggerImpl,
426
+ LogGraph,
427
+ LogHistogram,
428
+ LogImage,
429
+ LogLine,
430
+ LogMesh,
431
+ LogPing,
432
+ LogScalar,
433
+ LogStatus,
434
+ LogVideo,
435
+ )
388
436
  from xax.task.loggers.callback import CallbackLogger
389
437
  from xax.task.loggers.json import JsonLogger
390
438
  from xax.task.loggers.state import StateLogger
391
439
  from xax.task.loggers.stdout import StdoutLogger
392
440
  from xax.task.loggers.tensorboard import TensorboardLogger
441
+ from xax.task.mixins.checkpointing import load_ckpt
393
442
  from xax.task.mixins.cpu_stats import CPUStatsOptions
394
443
  from xax.task.mixins.data_loader import DataloaderConfig
395
444
  from xax.task.mixins.gpu_stats import GPUStatsOptions
@@ -439,6 +488,7 @@ if IMPORT_ALL or TYPE_CHECKING:
439
488
  compute_nan_ratio,
440
489
  flatten_array,
441
490
  flatten_pytree,
491
+ get_pytree_param_count,
442
492
  pytree_has_nans,
443
493
  reshuffle_pytree,
444
494
  reshuffle_pytree_along_dims,
@@ -26,7 +26,7 @@ def field(value: FieldType, **kwargs: str) -> FieldType:
26
26
  metadata: dict[str, Any] = {}
27
27
  metadata.update(kwargs)
28
28
 
29
- if hasattr(value, "__call__"):
29
+ if hasattr(value, "__call__"): # noqa: B004
30
30
  return field_base(default_factory=value, metadata=metadata)
31
31
  if value.__class__.__hash__ is None:
32
32
  return field_base(default_factory=lambda: value, metadata=metadata)
@@ -68,8 +68,8 @@ def _infer_activation(activation: ActivationFunction) -> Callable:
68
68
  return lambda x: x
69
69
  try:
70
70
  return getattr(jax.nn, activation)
71
- except AttributeError:
72
- raise ValueError(f"Activation function `{activation}` not found in `jax.nn`")
71
+ except AttributeError as err:
72
+ raise ValueError(f"Activation function `{activation}` not found in `jax.nn`") from err
73
73
 
74
74
 
75
75
  def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.MLP:
@@ -100,7 +100,7 @@ def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.ML
100
100
  def export_eqx_mlp(
101
101
  model: eqx.nn.MLP,
102
102
  output_path: str | Path,
103
- dtype: jax.numpy.dtype = eqx._misc.default_floating_dtype(),
103
+ dtype: jax.numpy.dtype | None = None,
104
104
  ) -> None:
105
105
  """Serialize an Equinox MLP to a .eqx file.
106
106
 
@@ -109,6 +109,9 @@ def export_eqx_mlp(
109
109
  output_path: The path to save the exported model.
110
110
  dtype: The dtype of the model.
111
111
  """
112
+ if dtype is None:
113
+ dtype = eqx._misc.default_floating_dtype()
114
+
112
115
  activation = model.activation.__name__
113
116
  final_activation = model.final_activation.__name__
114
117
 
@@ -1,5 +1,5 @@
1
1
  # mypy: disable-error-code="override"
2
- """Defines helper Torch functions."""
2
+ """Defines helper Jax functions."""
3
3
 
4
4
  import random
5
5
  from dataclasses import is_dataclass
@@ -58,13 +58,16 @@ def recursive_chunk(item: Any, num_chunks: int, dim: int = 0) -> Iterable[Any]:
58
58
  yield from np.array_split(item, num_chunks, axis=dim)
59
59
  elif is_dataclass(item):
60
60
  yield from (
61
- item.__class__(**{k: i for k, i in zip(item.__dict__, ii)})
62
- for ii in zip(*(recursive_chunk(v, num_chunks, dim) for v in item.__dict__.values()))
61
+ item.__class__(**{k: i for k, i in zip(item.__dict__, ii, strict=True)})
62
+ for ii in zip(*(recursive_chunk(v, num_chunks, dim) for v in item.__dict__.values()), strict=False)
63
63
  )
64
64
  elif isinstance(item, Mapping):
65
- yield from (dict(zip(item, ii)) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item.values())))
65
+ yield from (
66
+ dict(zip(item, ii, strict=False))
67
+ for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item.values()), strict=False)
68
+ )
66
69
  elif isinstance(item, Sequence):
67
- yield from (list(ii) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item)))
70
+ yield from (list(ii) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item), strict=False))
68
71
  else:
69
72
  yield from (item for _ in range(num_chunks))
70
73
 
@@ -1,5 +1,6 @@
1
1
  """Defines geometry functions."""
2
2
 
3
+ import chex
3
4
  from jax import numpy as jnp
4
5
  from jaxtyping import Array
5
6
 
@@ -211,3 +212,51 @@ def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
211
212
  ],
212
213
  axis=-2,
213
214
  )
215
+
216
+
217
+ def normalize(v: jnp.ndarray, axis: int = -1, eps: float = 1e-8) -> jnp.ndarray:
218
+ norm = jnp.linalg.norm(v, axis=axis, keepdims=True)
219
+ return v / jnp.clip(norm, a_min=eps)
220
+
221
+
222
+ def rotation6d_to_rotation_matrix(r6d: jnp.ndarray) -> jnp.ndarray:
223
+ """Convert 6D rotation representation to rotation matrix.
224
+
225
+ From https://arxiv.org/pdf/1812.07035, Appendix B
226
+
227
+ Args:
228
+ r6d: The 6D rotation representation, shape (*, 6).
229
+
230
+ Returns:
231
+ The rotation matrix, shape (*, 3, 3).
232
+ """
233
+ chex.assert_shape(r6d, (..., 6))
234
+ shape = r6d.shape
235
+ flat = r6d.reshape(-1, 6)
236
+ a_1 = flat[:, 0:3]
237
+ a_2 = flat[:, 3:6]
238
+
239
+ b_1 = normalize(a_1, axis=-1)
240
+
241
+ # Reordered Gram-Schmidt orthonormalization.
242
+ b_3 = normalize(jnp.cross(b_1, a_2), axis=-1)
243
+ b_2 = jnp.cross(b_3, b_1)
244
+
245
+ rotation_matrix = jnp.stack([b_1, b_2, b_3], axis=-1)
246
+ return rotation_matrix.reshape(shape[:-1] + (3, 3))
247
+
248
+
249
+ def rotation_matrix_to_rotation6d(rotation_matrix: jnp.ndarray) -> jnp.ndarray:
250
+ """Convert rotation matrix to 6D rotation representation.
251
+
252
+ Args:
253
+ rotation_matrix: The rotation matrix, shape (*, 3, 3).
254
+
255
+ Returns:
256
+ The 6D rotation representation, shape (*, 6).
257
+ """
258
+ chex.assert_shape(rotation_matrix, (..., 3, 3))
259
+ shape = rotation_matrix.shape
260
+ # Simply concatenate a1 and a2 from SO(3)
261
+ r6d = jnp.concatenate([rotation_matrix[..., 0], rotation_matrix[..., 1]], axis=-1)
262
+ return r6d.reshape(shape[:-2] + (6,))
@@ -184,8 +184,8 @@ class BaseTask(Generic[Config]):
184
184
 
185
185
  # Attempts to load any paths as configs.
186
186
  is_path = [Path(arg).is_file() or (task_path / arg).is_file() for arg in args]
187
- paths = [arg for arg, is_path in zip(args, is_path) if is_path]
188
- non_paths = [arg for arg, is_path in zip(args, is_path) if not is_path]
187
+ paths = [arg for arg, is_path in zip(args, is_path, strict=True) if is_path]
188
+ non_paths = [arg for arg, is_path in zip(args, is_path, strict=True) if not is_path]
189
189
  if paths:
190
190
  cfg = OmegaConf.merge(cfg, *(get_config(path, task_path) for path in paths))
191
191
  cfg = OmegaConf.merge(cfg, OmegaConf.from_cli(non_paths))
@@ -462,11 +462,11 @@ class LoggerImpl(ABC):
462
462
 
463
463
  self.tickers = {phase: IntervalTicker(log_interval_seconds) for phase in get_args(Phase)}
464
464
 
465
- def start(self) -> None:
466
- pass
465
+ @abstractmethod
466
+ def start(self) -> None: ...
467
467
 
468
- def stop(self) -> None:
469
- pass
468
+ @abstractmethod
469
+ def stop(self) -> None: ...
470
470
 
471
471
  @abstractmethod
472
472
  def write(self, line: LogLine) -> None:
@@ -476,6 +476,7 @@ class LoggerImpl(ABC):
476
476
  line: The line to write.
477
477
  """
478
478
 
479
+ @abstractmethod
479
480
  def write_error_summary(self, error_summary: LogErrorSummary) -> None:
480
481
  """Handles writing an error summary.
481
482
 
@@ -483,6 +484,7 @@ class LoggerImpl(ABC):
483
484
  error_summary: The error summary to write.
484
485
  """
485
486
 
487
+ @abstractmethod
486
488
  def write_error(self, error: LogError) -> None:
487
489
  """Handles writing an error line.
488
490
 
@@ -490,6 +492,7 @@ class LoggerImpl(ABC):
490
492
  error: The error information to write.
491
493
  """
492
494
 
495
+ @abstractmethod
493
496
  def write_status(self, status: LogStatus) -> None:
494
497
  """Handles writing a status line.
495
498
 
@@ -497,6 +500,7 @@ class LoggerImpl(ABC):
497
500
  status: The status to write.
498
501
  """
499
502
 
503
+ @abstractmethod
500
504
  def write_ping(self, ping: LogPing) -> None:
501
505
  """Handles writing a ping line.
502
506
 
@@ -504,6 +508,7 @@ class LoggerImpl(ABC):
504
508
  ping: The ping to write.
505
509
  """
506
510
 
511
+ @abstractmethod
507
512
  def log_file(self, name: str, contents: str) -> None:
508
513
  """Logs a large text file.
509
514
 
@@ -621,7 +626,7 @@ class Logger:
621
626
  return
622
627
  line = self.pack(state)
623
628
  self.clear()
624
- for lg in (lg for lg, should_log in zip(self.loggers, should_log) if should_log):
629
+ for lg in (lg for lg, should_log in zip(self.loggers, should_log, strict=False) if should_log):
625
630
  lg.write(line)
626
631
 
627
632
  def write_error_summary(self, error_summary: str) -> None:
@@ -1045,7 +1050,7 @@ class Logger:
1045
1050
  line_spacing=line_spacing,
1046
1051
  centered=centered,
1047
1052
  )
1048
- for img, label in zip(images, labels)
1053
+ for img, label in zip(images, labels, strict=True)
1049
1054
  ]
1050
1055
  tiled = tile_images([img.image for img in labeled], sep)
1051
1056
 
@@ -25,6 +25,12 @@ class CallbackLogger(LoggerImpl):
25
25
  self.ping_callback = ping_callback
26
26
  self.file_callback = file_callback
27
27
 
28
+ def start(self) -> None:
29
+ pass
30
+
31
+ def stop(self) -> None:
32
+ pass
33
+
28
34
  def write(self, line: LogLine) -> None:
29
35
  self.callback(line)
30
36
 
@@ -2,13 +2,13 @@
2
2
 
3
3
  import json
4
4
  import sys
5
- from dataclasses import asdict
6
5
  from typing import Any, Literal, Mapping, TextIO
7
6
 
8
7
  from jaxtyping import Array
9
8
 
10
9
  from xax.task.logger import (
11
10
  LogError,
11
+ LogErrorSummary,
12
12
  LoggerImpl,
13
13
  LogLine,
14
14
  LogPing,
@@ -58,6 +58,12 @@ class JsonLogger(LoggerImpl):
58
58
  self.line_sep = line_sep
59
59
  self.remove_unicode_from_namespaces = remove_unicode_from_namespaces
60
60
 
61
+ def start(self) -> None:
62
+ pass
63
+
64
+ def stop(self) -> None:
65
+ pass
66
+
61
67
  @property
62
68
  def fp(self) -> TextIO:
63
69
  return self.log_stream
@@ -67,7 +73,7 @@ class JsonLogger(LoggerImpl):
67
73
  return self.err_log_stream
68
74
 
69
75
  def get_json(self, line: LogLine) -> str:
70
- data: dict = {"state": asdict(line.state)}
76
+ data: dict = {"state": line.state.to_dict()}
71
77
 
72
78
  def add_logs(log: Mapping[str, Mapping[str, LogScalar | LogString]], data: dict) -> None:
73
79
  for namespace, values in log.items():
@@ -88,6 +94,12 @@ class JsonLogger(LoggerImpl):
88
94
  if self.flush_immediately:
89
95
  self.fp.flush()
90
96
 
97
+ def write_error_summary(self, error_summary: LogErrorSummary) -> None:
98
+ pass
99
+
100
+ def log_file(self, name: str, contents: str) -> None:
101
+ pass
102
+
91
103
  def write_error(self, error: LogError) -> None:
92
104
  self.err_fp.write(error.message)
93
105
  if error.location is not None:
@@ -3,7 +3,14 @@
3
3
  from pathlib import Path
4
4
  from typing import Literal
5
5
 
6
- from xax.task.logger import LoggerImpl, LogLine
6
+ from xax.task.logger import (
7
+ LogError,
8
+ LogErrorSummary,
9
+ LoggerImpl,
10
+ LogLine,
11
+ LogPing,
12
+ LogStatus,
13
+ )
7
14
 
8
15
 
9
16
  class StateLogger(LoggerImpl):
@@ -30,3 +37,21 @@ class StateLogger(LoggerImpl):
30
37
 
31
38
  def write(self, line: LogLine) -> None:
32
39
  pass
40
+
41
+ def start(self) -> None:
42
+ pass
43
+
44
+ def stop(self) -> None:
45
+ pass
46
+
47
+ def write_error_summary(self, error_summary: LogErrorSummary) -> None:
48
+ pass
49
+
50
+ def write_error(self, error: LogError) -> None:
51
+ pass
52
+
53
+ def write_status(self, status: LogStatus) -> None:
54
+ pass
55
+
56
+ def write_ping(self, ping: LogPing) -> None:
57
+ pass
@@ -79,11 +79,13 @@ class StdoutLogger(LoggerImpl):
79
79
  self.error_summary: tuple[str, datetime.datetime] | None = None
80
80
 
81
81
  def start(self) -> None:
82
- return super().start()
82
+ pass
83
83
 
84
84
  def stop(self) -> None:
85
85
  self.write_queues()
86
- return super().stop()
86
+
87
+ def log_file(self, name: str, contents: str) -> None:
88
+ pass
87
89
 
88
90
  def write_separator(self) -> None:
89
91
  self.write_fp.write("\033[2J\033[H")
@@ -12,7 +12,7 @@ from typing import TypeVar
12
12
 
13
13
  from xax.core.state import Phase
14
14
  from xax.nn.parallel import is_master
15
- from xax.task.logger import LoggerImpl, LogLine
15
+ from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
16
16
  from xax.utils.jax import as_float
17
17
  from xax.utils.logging import LOG_STATUS, port_is_busy
18
18
  from xax.utils.tensorboard import TensorboardWriter, TensorboardWriters
@@ -236,3 +236,21 @@ class TensorboardLogger(LoggerImpl):
236
236
  for name, contents in self.files.items():
237
237
  writer.add_text(name, contents)
238
238
  self.files.clear()
239
+
240
+ def start(self) -> None:
241
+ pass
242
+
243
+ def stop(self) -> None:
244
+ pass
245
+
246
+ def write_error(self, error: LogError) -> None:
247
+ pass
248
+
249
+ def write_error_summary(self, error_summary: LogErrorSummary) -> None:
250
+ pass
251
+
252
+ def write_ping(self, ping: LogPing) -> None:
253
+ pass
254
+
255
+ def write_status(self, status: LogStatus) -> None:
256
+ pass
@@ -31,11 +31,13 @@ Config = TypeVar("Config", bound=ArtifactsConfig)
31
31
 
32
32
  class ArtifactsMixin(BaseTask[Config]):
33
33
  _exp_dir: Path | None
34
+ _stage_dir: Path | None
34
35
 
35
36
  def __init__(self, config: Config) -> None:
36
37
  super().__init__(config)
37
38
 
38
39
  self._exp_dir = None
40
+ self._stage_dir = None
39
41
 
40
42
  @functools.cached_property
41
43
  def run_dir(self) -> Path:
@@ -75,15 +77,16 @@ class ArtifactsMixin(BaseTask[Config]):
75
77
  logger.log(LOG_STATUS, self._exp_dir)
76
78
  return self._exp_dir
77
79
 
78
- @functools.lru_cache(maxsize=None)
79
80
  def stage_environment(self) -> Path | None:
80
- stage_dir = (self.exp_dir / "code").resolve()
81
- try:
82
- stage_environment(self, stage_dir)
83
- except Exception:
84
- logger.exception("Failed to stage environment!")
85
- return None
86
- return stage_dir
81
+ if self._stage_dir is None:
82
+ stage_dir = (self.exp_dir / "code").resolve()
83
+ try:
84
+ stage_environment(self, stage_dir)
85
+ except Exception:
86
+ logger.exception("Failed to stage environment!")
87
+ return None
88
+ self._stage_dir = stage_dir
89
+ return self._stage_dir
87
90
 
88
91
  def on_training_end(self, state: State) -> State:
89
92
  state = super().on_training_end(state)
@@ -52,6 +52,114 @@ class CheckpointingConfig(ArtifactsConfig):
52
52
  Config = TypeVar("Config", bound=CheckpointingConfig)
53
53
 
54
54
 
55
+ @overload
56
+ def load_ckpt(
57
+ path: Path,
58
+ *,
59
+ part: Literal["all"],
60
+ model_template: PyTree,
61
+ optimizer_template: PyTree,
62
+ opt_state_template: PyTree,
63
+ ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
64
+
65
+
66
+ @overload
67
+ def load_ckpt(
68
+ path: Path,
69
+ *,
70
+ part: Literal["model_state_config"],
71
+ model_template: PyTree,
72
+ ) -> tuple[PyTree, State, DictConfig]: ...
73
+
74
+
75
+ @overload
76
+ def load_ckpt(path: Path, *, part: Literal["model"], model_template: PyTree) -> PyTree: ...
77
+
78
+
79
+ @overload
80
+ def load_ckpt(path: Path, *, part: Literal["opt"], optimizer_template: PyTree) -> optax.GradientTransformation: ...
81
+
82
+
83
+ @overload
84
+ def load_ckpt(path: Path, *, part: Literal["opt_state"], opt_state_template: PyTree) -> optax.OptState: ...
85
+
86
+
87
+ @overload
88
+ def load_ckpt(path: Path, *, part: Literal["state"]) -> State: ...
89
+
90
+
91
+ @overload
92
+ def load_ckpt(path: Path, *, part: Literal["config"]) -> DictConfig: ...
93
+
94
+
95
+ def load_ckpt(
96
+ path: str | Path,
97
+ *,
98
+ part: CheckpointPart = "model",
99
+ model_template: PyTree | None = None,
100
+ optimizer_template: PyTree | None = None,
101
+ opt_state_template: PyTree | None = None,
102
+ ) -> (
103
+ tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
104
+ | tuple[PyTree, State, DictConfig]
105
+ | PyTree
106
+ | optax.GradientTransformation
107
+ | optax.OptState
108
+ | State
109
+ | DictConfig
110
+ ):
111
+ with tarfile.open(path, "r:gz") as tar:
112
+
113
+ def get_model() -> PyTree:
114
+ if model_template is None:
115
+ raise ValueError("model_template must be provided to load model weights")
116
+ if (model := tar.extractfile("model")) is None:
117
+ raise ValueError(f"Checkpoint does not contain a model file: {path}")
118
+ return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
119
+
120
+ def get_opt() -> optax.GradientTransformation:
121
+ if optimizer_template is None:
122
+ raise ValueError("optimizer_template must be provided to load optimizer")
123
+ if (opt := tar.extractfile("optimizer")) is None:
124
+ raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
125
+ return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
126
+
127
+ def get_opt_state() -> optax.OptState:
128
+ if opt_state_template is None:
129
+ raise ValueError("opt_state_template must be provided to load optimizer state")
130
+ if (opt_state := tar.extractfile("opt_state")) is None:
131
+ raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
132
+ return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
133
+
134
+ def get_state() -> State:
135
+ if (state := tar.extractfile("state")) is None:
136
+ raise ValueError(f"Checkpoint does not contain a state file: {path}")
137
+ return State.from_dict(**json.loads(state.read().decode()))
138
+
139
+ def get_config() -> DictConfig:
140
+ if (config := tar.extractfile("config")) is None:
141
+ raise ValueError(f"Checkpoint does not contain a config file: {path}")
142
+ return cast(DictConfig, OmegaConf.load(config))
143
+
144
+ match part:
145
+ case "model":
146
+ return get_model()
147
+ case "opt":
148
+ return get_opt()
149
+ case "opt_state":
150
+ return get_opt_state()
151
+ case "state":
152
+ return get_state()
153
+ case "config":
154
+ return get_config()
155
+ case "model_state_config":
156
+ return get_model(), get_state(), get_config()
157
+ case "all":
158
+ return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
159
+ case _:
160
+ raise ValueError(f"Invalid checkpoint part: {part}")
161
+
162
+
55
163
  class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
56
164
  def __init__(self, config: Config) -> None:
57
165
  super().__init__(config)
@@ -82,149 +190,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
82
190
  return True
83
191
  return False
84
192
 
85
- @overload
86
- def load_ckpt_with_template(
87
- self,
88
- path: Path,
89
- *,
90
- part: Literal["all"],
91
- model_template: PyTree,
92
- optimizer_template: PyTree,
93
- opt_state_template: PyTree,
94
- ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
95
-
96
- @overload
97
- def load_ckpt_with_template(
98
- self,
99
- path: Path,
100
- *,
101
- part: Literal["model_state_config"],
102
- model_template: PyTree,
103
- ) -> tuple[PyTree, State, Config]: ...
104
-
105
- @overload
106
- def load_ckpt_with_template(
107
- self,
108
- path: Path,
109
- *,
110
- part: Literal["model"],
111
- model_template: PyTree,
112
- ) -> PyTree: ...
113
-
114
- @overload
115
- def load_ckpt_with_template(
116
- self,
117
- path: Path,
118
- *,
119
- part: Literal["opt"],
120
- optimizer_template: PyTree,
121
- ) -> optax.GradientTransformation: ...
122
-
123
- @overload
124
- def load_ckpt_with_template(
125
- self,
126
- path: Path,
127
- *,
128
- part: Literal["opt_state"],
129
- opt_state_template: PyTree,
130
- ) -> optax.OptState: ...
131
-
132
- @overload
133
- def load_ckpt_with_template(
134
- self,
135
- path: Path,
136
- *,
137
- part: Literal["state"],
138
- ) -> State: ...
139
-
140
- @overload
141
- def load_ckpt_with_template(
142
- self,
143
- path: Path,
144
- *,
145
- part: Literal["config"],
146
- ) -> Config: ...
147
-
148
- def load_ckpt_with_template(
149
- self,
150
- path: Path,
151
- *,
152
- part: CheckpointPart = "all",
153
- model_template: PyTree | None = None,
154
- optimizer_template: PyTree | None = None,
155
- opt_state_template: PyTree | None = None,
156
- ) -> (
157
- tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
158
- | tuple[PyTree, State, Config]
159
- | PyTree
160
- | optax.GradientTransformation
161
- | optax.OptState
162
- | State
163
- | Config
164
- ):
165
- """Load a checkpoint.
166
-
167
- Args:
168
- path: Path to the checkpoint directory
169
- part: Which part of the checkpoint to load
170
- model_template: Template model with correct structure but uninitialized weights
171
- optimizer_template: Template optimizer with correct structure but uninitialized weights
172
- opt_state_template: Template optimizer state with correct structure but uninitialized weights
173
-
174
- Returns:
175
- The requested checkpoint components
176
- """
177
- with tarfile.open(path, "r:gz") as tar:
178
-
179
- def get_model() -> PyTree:
180
- if model_template is None:
181
- raise ValueError("model_template must be provided to load model weights")
182
- if (model := tar.extractfile("model")) is None:
183
- raise ValueError(f"Checkpoint does not contain a model file: {path}")
184
- return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
185
-
186
- def get_opt() -> optax.GradientTransformation:
187
- if optimizer_template is None:
188
- raise ValueError("optimizer_template must be provided to load optimizer")
189
- if (opt := tar.extractfile("optimizer")) is None:
190
- raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
191
- return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
192
-
193
- def get_opt_state() -> optax.OptState:
194
- if opt_state_template is None:
195
- raise ValueError("opt_state_template must be provided to load optimizer state")
196
- if (opt_state := tar.extractfile("opt_state")) is None:
197
- raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
198
- return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
199
-
200
- def get_state() -> State:
201
- if (state := tar.extractfile("state")) is None:
202
- raise ValueError(f"Checkpoint does not contain a state file: {path}")
203
- return State.from_dict(**json.loads(state.read().decode()))
204
-
205
- def get_config() -> Config:
206
- if (config := tar.extractfile("config")) is None:
207
- raise ValueError(f"Checkpoint does not contain a config file: {path}")
208
- return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)
209
-
210
- match part:
211
- case "model":
212
- return get_model()
213
- case "opt":
214
- return get_opt()
215
- case "opt_state":
216
- return get_opt_state()
217
- case "state":
218
- return get_state()
219
- case "config":
220
- return get_config()
221
- case "model_state_config":
222
- return get_model(), get_state(), get_config()
223
- case "all":
224
- return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
225
- case _:
226
- raise ValueError(f"Invalid checkpoint part: {part}")
227
-
228
193
  def save_checkpoint(
229
194
  self,
230
195
  model: PyTree | None = None,
@@ -40,7 +40,7 @@ from xax.core.state import Phase, State
40
40
  from xax.nn.functions import set_random_seed
41
41
  from xax.nn.parallel import is_master
42
42
  from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
43
- from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart
43
+ from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart, load_ckpt
44
44
  from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
45
45
  from xax.task.mixins.logger import LoggerConfig, LoggerMixin
46
46
  from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
@@ -57,6 +57,7 @@ from xax.utils.experiments import (
57
57
  )
58
58
  from xax.utils.jax import jit as xax_jit
59
59
  from xax.utils.logging import LOG_PING, LOG_STATUS
60
+ from xax.utils.pytree import get_pytree_param_count
60
61
  from xax.utils.text import highlight_exception_message, show_info
61
62
  from xax.utils.types.frozen_dict import FrozenDict
62
63
 
@@ -360,6 +361,7 @@ class TrainMixin(
360
361
  model = self.get_model(key)
361
362
  state = State.init_state()
362
363
 
364
+ self.log_model_size(model)
363
365
  if not load_optimizer:
364
366
  return model, state
365
367
 
@@ -450,44 +452,43 @@ class TrainMixin(
450
452
  match part:
451
453
  case "model_state_config":
452
454
  model_spec = eqx.filter_eval_shape(self.get_model, key)
453
- return self.load_ckpt_with_template(path, part="model_state_config", model_template=model_spec)
455
+ model, state, config = load_ckpt(path, part="model_state_config", model_template=model_spec)
456
+ config = self.get_config(config, use_cli=False)
457
+ return model, state, config
454
458
 
455
459
  case "model":
456
460
  model_spec = eqx.filter_eval_shape(self.get_model, key)
457
- return self.load_ckpt_with_template(path, part="model", model_template=model_spec)
458
-
459
- case "config":
460
- return self.load_ckpt_with_template(path, part="config")
461
+ return load_ckpt(path, part="model", model_template=model_spec)
461
462
 
462
463
  case "opt":
463
464
  optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
464
- return self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
465
+ return load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
465
466
 
466
467
  case "opt_state":
467
468
  if model is None:
468
469
  model_spec = eqx.filter_eval_shape(self.get_model, key)
469
- model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
470
+ model = load_ckpt(path, part="model", model_template=model_spec)
470
471
  if optimizer is None:
471
472
  optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
472
- optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
473
+ optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
473
474
  opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
474
- return self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
475
+ return load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
475
476
 
476
477
  case "state":
477
- return self.load_ckpt_with_template(path, part="state")
478
+ return load_ckpt(path, part="state")
478
479
 
479
480
  case "config":
480
- return self.load_ckpt_with_template(path, part="config")
481
+ return self.get_config(load_ckpt(path, part="config"), use_cli=False)
481
482
 
482
483
  case "all":
483
484
  model_spec = eqx.filter_eval_shape(self.get_model, key)
484
- model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
485
+ model = load_ckpt(path, part="model", model_template=model_spec)
485
486
  optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
486
- optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
487
+ optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
487
488
  opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
488
- opt_state = self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
489
- state = self.load_ckpt_with_template(path, part="state")
490
- config = self.load_ckpt_with_template(path, part="config")
489
+ opt_state = load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
490
+ state = load_ckpt(path, part="state")
491
+ config = self.get_config(load_ckpt(path, part="config"), use_cli=False)
491
492
  return model, optimizer, opt_state, state, config
492
493
 
493
494
  case _:
@@ -683,6 +684,9 @@ class TrainMixin(
683
684
  self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
684
685
  self.logger.log_file("info.json", get_info_json())
685
686
 
687
+ def log_model_size(self, model: PyTree) -> None:
688
+ logger.info("Model size: %s", f"{get_pytree_param_count(model):,}")
689
+
686
690
  def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
687
691
  return eqx.is_inexact_array(item)
688
692
 
@@ -749,7 +749,8 @@ class BaseFileDownloader(ABC):
749
749
  f"We detected some HTML elements in the downloaded file. "
750
750
  f"This most likely means that the download triggered an unhandled API response by GDrive. "
751
751
  f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
752
- f"the response:\n\n{text}"
752
+ f"the response:\n\n{text}",
753
+ stacklevel=2,
753
754
  )
754
755
 
755
756
  @classmethod
@@ -3,10 +3,10 @@
3
3
  from pathlib import Path
4
4
 
5
5
  import jax
6
- import jax.core
6
+ import jax.extend.core
7
7
 
8
8
 
9
- def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) -> None:
9
+ def save_jaxpr_dot(closed_jaxpr: jax.extend.core.ClosedJaxpr, filename: str | Path) -> None:
10
10
  """Save the JAXPR to a DOT file.
11
11
 
12
12
  Example usage:
@@ -30,15 +30,15 @@ def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) ->
30
30
  with open(filename, "w") as f:
31
31
  f.write("digraph Jaxpr {\n")
32
32
 
33
- var_names: dict[jax.core.Var, str] = {}
33
+ var_names: dict[jax.extend.core.Var, str] = {}
34
34
  var_count = 0
35
35
 
36
- def get_var_name(var: jax.core.Var) -> str:
36
+ def get_var_name(var: jax.extend.core.Var) -> str:
37
37
  """Get a unique name for a variable."""
38
38
  nonlocal var_names, var_count
39
39
 
40
40
  # Handle Literal objects specially since they're not hashable
41
- if isinstance(var, jax.core.Literal):
41
+ if isinstance(var, jax.extend.core.Literal):
42
42
  # Create a name based on the literal value
43
43
  name = f"lit_{var.val}"
44
44
  return name
@@ -1,6 +1,7 @@
1
1
  """Utils for accessing, modifying, and otherwise manipulating pytrees."""
2
2
 
3
3
  import chex
4
+ import equinox as eqx
4
5
  import jax
5
6
  import jax.numpy as jnp
6
7
  from jax import Array
@@ -57,7 +58,7 @@ def pytree_has_nans(pytree: PyTree) -> Array:
57
58
 
58
59
  def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
59
60
  """Update a pytree based on a condition."""
60
- # Tricky, need use tree_map because where expects array leafs.
61
+ # Tricky, need use tree.map because where expects array leafs.
61
62
  return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)
62
63
 
63
64
 
@@ -124,7 +125,7 @@ def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArr
124
125
  def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
125
126
  """Reshuffle a rollout array across arbitrary batch dimensions independently of each other."""
126
127
  rngs = jax.random.split(rng, len(batch_shape))
127
- perms = [jax.random.permutation(rng_i, dim) for rng_i, dim in zip(rngs, batch_shape)]
128
+ perms = [jax.random.permutation(rng_i, dim) for rng_i, dim in zip(rngs, batch_shape, strict=True)]
128
129
  # n-dimensional index grid from permutations
129
130
  idx_grids = jnp.meshgrid(*perms, indexing="ij")
130
131
 
@@ -236,3 +237,9 @@ def reshuffle_pytree_along_dims(
236
237
  return x
237
238
 
238
239
  return jax.tree.map_with_path(restore_transpose, reshuffled_transposed)
240
+
241
+
242
+ def get_pytree_param_count(pytree: PyTree) -> int:
243
+ """Calculates the total number of parameters in a PyTree."""
244
+ leaves, _ = jax.tree.flatten(pytree)
245
+ return sum(x.size for x in leaves if isinstance(x, jnp.ndarray) and eqx.is_inexact_array(x))
@@ -192,7 +192,7 @@ def render_text_blocks(
192
192
  if any(len(row) != len(blocks[0]) for row in blocks):
193
193
  raise ValueError("All rows must have the same number of blocks in order to align them")
194
194
  widths = [[max(len(line) for line in i.lines) if i.width is None else i.width for i in r] for r in blocks]
195
- row_widths = [max(i) for i in zip(*widths)]
195
+ row_widths = [max(i) for i in zip(*widths, strict=True)]
196
196
  for row in blocks:
197
197
  for i, block in enumerate(row):
198
198
  block.width = row_widths[i]
@@ -263,7 +263,7 @@ def render_text_blocks(
263
263
  if i >= len(block.lines)
264
264
  else colored(pad(block.lines[i], width, block.center), block.color, bold=block.bold)
265
265
  )
266
- for block, width in zip(row, get_widths(row))
266
+ for block, width in zip(row, get_widths(row), strict=True)
267
267
  ]
268
268
  )
269
269
  + " │"
@@ -133,12 +133,12 @@ class FrozenDict(Mapping[K, V]):
133
133
 
134
134
  @classmethod
135
135
  def tree_unflatten(cls, keys: tuple[K, ...], values: tuple[Any, ...]) -> "FrozenDict[K, V]":
136
- return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True)
136
+ return cls({k: v for k, v in zip(keys, values, strict=True)}, __unsafe_skip_copy__=True)
137
137
 
138
138
 
139
139
  def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]: # noqa: ANN401
140
140
  if isinstance(x, FrozenDict):
141
- return jax.tree_util.tree_map(lambda y: y, x._dict)
141
+ return jax.tree.map(lambda y: y, x._dict)
142
142
  elif isinstance(x, dict):
143
143
  ys = {}
144
144
  for key, value in x.items():
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.6
3
+ Version: 0.2.8
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes