xax 0.2.7__tar.gz → 0.2.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.2.7/xax.egg-info → xax-0.2.8}/PKG-INFO +1 -1
  2. {xax-0.2.7 → xax-0.2.8}/pyproject.toml +1 -0
  3. {xax-0.2.7 → xax-0.2.8}/xax/__init__.py +50 -5
  4. {xax-0.2.7 → xax-0.2.8}/xax/core/conf.py +1 -1
  5. {xax-0.2.7 → xax-0.2.8}/xax/nn/equinox.py +6 -3
  6. {xax-0.2.7 → xax-0.2.8}/xax/nn/functions.py +7 -4
  7. {xax-0.2.7 → xax-0.2.8}/xax/nn/geom.py +49 -0
  8. {xax-0.2.7 → xax-0.2.8}/xax/task/base.py +2 -2
  9. {xax-0.2.7 → xax-0.2.8}/xax/task/logger.py +11 -6
  10. {xax-0.2.7 → xax-0.2.8}/xax/task/loggers/callback.py +6 -0
  11. {xax-0.2.7 → xax-0.2.8}/xax/task/loggers/json.py +13 -0
  12. {xax-0.2.7 → xax-0.2.8}/xax/task/loggers/state.py +26 -1
  13. {xax-0.2.7 → xax-0.2.8}/xax/task/loggers/stdout.py +4 -2
  14. {xax-0.2.7 → xax-0.2.8}/xax/task/loggers/tensorboard.py +19 -1
  15. {xax-0.2.7 → xax-0.2.8}/xax/task/mixins/artifacts.py +11 -8
  16. {xax-0.2.7 → xax-0.2.8}/xax/task/mixins/train.py +2 -7
  17. {xax-0.2.7 → xax-0.2.8}/xax/utils/experiments.py +2 -1
  18. {xax-0.2.7 → xax-0.2.8}/xax/utils/pytree.py +8 -1
  19. {xax-0.2.7 → xax-0.2.8}/xax/utils/text.py +2 -2
  20. {xax-0.2.7 → xax-0.2.8}/xax/utils/types/frozen_dict.py +1 -1
  21. {xax-0.2.7 → xax-0.2.8/xax.egg-info}/PKG-INFO +1 -1
  22. {xax-0.2.7 → xax-0.2.8}/LICENSE +0 -0
  23. {xax-0.2.7 → xax-0.2.8}/MANIFEST.in +0 -0
  24. {xax-0.2.7 → xax-0.2.8}/README.md +0 -0
  25. {xax-0.2.7 → xax-0.2.8}/setup.cfg +0 -0
  26. {xax-0.2.7 → xax-0.2.8}/setup.py +0 -0
  27. {xax-0.2.7 → xax-0.2.8}/xax/core/__init__.py +0 -0
  28. {xax-0.2.7 → xax-0.2.8}/xax/core/state.py +0 -0
  29. {xax-0.2.7 → xax-0.2.8}/xax/nn/__init__.py +0 -0
  30. {xax-0.2.7 → xax-0.2.8}/xax/nn/embeddings.py +0 -0
  31. {xax-0.2.7 → xax-0.2.8}/xax/nn/export.py +0 -0
  32. {xax-0.2.7 → xax-0.2.8}/xax/nn/losses.py +0 -0
  33. {xax-0.2.7 → xax-0.2.8}/xax/nn/norm.py +0 -0
  34. {xax-0.2.7 → xax-0.2.8}/xax/nn/parallel.py +0 -0
  35. {xax-0.2.7 → xax-0.2.8}/xax/nn/ssm.py +0 -0
  36. {xax-0.2.7 → xax-0.2.8}/xax/py.typed +0 -0
  37. {xax-0.2.7 → xax-0.2.8}/xax/requirements-dev.txt +0 -0
  38. {xax-0.2.7 → xax-0.2.8}/xax/requirements.txt +0 -0
  39. {xax-0.2.7 → xax-0.2.8}/xax/task/__init__.py +0 -0
  40. {xax-0.2.7 → xax-0.2.8}/xax/task/launchers/__init__.py +0 -0
  41. {xax-0.2.7 → xax-0.2.8}/xax/task/launchers/base.py +0 -0
  42. {xax-0.2.7 → xax-0.2.8}/xax/task/launchers/cli.py +0 -0
  43. {xax-0.2.7 → xax-0.2.8}/xax/task/launchers/single_process.py +0 -0
  44. {xax-0.2.7 → xax-0.2.8}/xax/task/loggers/__init__.py +0 -0
  45. {xax-0.2.7 → xax-0.2.8}/xax/task/mixins/__init__.py +0 -0
  46. {xax-0.2.7 → xax-0.2.8}/xax/task/mixins/checkpointing.py +0 -0
  47. {xax-0.2.7 → xax-0.2.8}/xax/task/mixins/compile.py +0 -0
  48. {xax-0.2.7 → xax-0.2.8}/xax/task/mixins/cpu_stats.py +0 -0
  49. {xax-0.2.7 → xax-0.2.8}/xax/task/mixins/data_loader.py +0 -0
  50. {xax-0.2.7 → xax-0.2.8}/xax/task/mixins/gpu_stats.py +0 -0
  51. {xax-0.2.7 → xax-0.2.8}/xax/task/mixins/logger.py +0 -0
  52. {xax-0.2.7 → xax-0.2.8}/xax/task/mixins/process.py +0 -0
  53. {xax-0.2.7 → xax-0.2.8}/xax/task/mixins/runnable.py +0 -0
  54. {xax-0.2.7 → xax-0.2.8}/xax/task/mixins/step_wrapper.py +0 -0
  55. {xax-0.2.7 → xax-0.2.8}/xax/task/script.py +0 -0
  56. {xax-0.2.7 → xax-0.2.8}/xax/task/task.py +0 -0
  57. {xax-0.2.7 → xax-0.2.8}/xax/utils/__init__.py +0 -0
  58. {xax-0.2.7 → xax-0.2.8}/xax/utils/data/__init__.py +0 -0
  59. {xax-0.2.7 → xax-0.2.8}/xax/utils/data/collate.py +0 -0
  60. {xax-0.2.7 → xax-0.2.8}/xax/utils/debugging.py +0 -0
  61. {xax-0.2.7 → xax-0.2.8}/xax/utils/jax.py +0 -0
  62. {xax-0.2.7 → xax-0.2.8}/xax/utils/jaxpr.py +0 -0
  63. {xax-0.2.7 → xax-0.2.8}/xax/utils/logging.py +0 -0
  64. {xax-0.2.7 → xax-0.2.8}/xax/utils/numpy.py +0 -0
  65. {xax-0.2.7 → xax-0.2.8}/xax/utils/profile.py +0 -0
  66. {xax-0.2.7 → xax-0.2.8}/xax/utils/tensorboard.py +0 -0
  67. {xax-0.2.7 → xax-0.2.8}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.2.7 → xax-0.2.8}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.7 → xax-0.2.8}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.7 → xax-0.2.8}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.7 → xax-0.2.8}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.7 → xax-0.2.8}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.7
3
+ Version: 0.2.8
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -57,6 +57,7 @@ target-version = "py311"
57
57
 
58
58
  select = [
59
59
  "ANN",
60
+ "B",
60
61
  "D",
61
62
  "E",
62
63
  "F",
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.7"
15
+ __version__ = "0.2.8"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -43,9 +43,12 @@ __all__ = [
43
43
  "cubic_bezier_interpolation",
44
44
  "euler_to_quat",
45
45
  "get_projected_gravity_vector_from_quat",
46
+ "normalize",
46
47
  "quat_to_euler",
47
48
  "quat_to_rotmat",
48
49
  "rotate_vector_by_quat",
50
+ "rotation6d_to_rotation_matrix",
51
+ "rotation_matrix_to_rotation6d",
49
52
  "cross_entropy",
50
53
  "cast_norm_type",
51
54
  "get_norm",
@@ -57,8 +60,18 @@ __all__ = [
57
60
  "BaseLauncher",
58
61
  "CliLauncher",
59
62
  "SingleProcessLauncher",
63
+ "LogDistribution",
64
+ "LogError",
65
+ "LogErrorSummary",
66
+ "LogGraph",
67
+ "LogHistogram",
60
68
  "LogImage",
61
69
  "LogLine",
70
+ "LogMesh",
71
+ "LogPing",
72
+ "LogScalar",
73
+ "LogStatus",
74
+ "LogVideo",
62
75
  "Logger",
63
76
  "LoggerImpl",
64
77
  "CallbackLogger",
@@ -72,7 +85,6 @@ __all__ = [
72
85
  "GPUStatsOptions",
73
86
  "StepContext",
74
87
  "ValidStepTimer",
75
- "get_param_count",
76
88
  "Script",
77
89
  "ScriptConfig",
78
90
  "Config",
@@ -117,6 +129,7 @@ __all__ = [
117
129
  "compute_nan_ratio",
118
130
  "flatten_array",
119
131
  "flatten_pytree",
132
+ "get_pytree_param_count",
120
133
  "pytree_has_nans",
121
134
  "reshuffle_pytree",
122
135
  "reshuffle_pytree_along_dims",
@@ -209,9 +222,12 @@ NAME_MAP: dict[str, str] = {
209
222
  "cubic_bezier_interpolation": "nn.geom",
210
223
  "euler_to_quat": "nn.geom",
211
224
  "get_projected_gravity_vector_from_quat": "nn.geom",
225
+ "normalize": "nn.geom",
212
226
  "quat_to_euler": "nn.geom",
213
227
  "quat_to_rotmat": "nn.geom",
214
228
  "rotate_vector_by_quat": "nn.geom",
229
+ "rotation6d_to_rotation_matrix": "nn.geom",
230
+ "rotation_matrix_to_rotation6d": "nn.geom",
215
231
  "cross_entropy": "nn.losses",
216
232
  "cast_norm_type": "nn.norm",
217
233
  "get_norm": "nn.norm",
@@ -223,8 +239,18 @@ NAME_MAP: dict[str, str] = {
223
239
  "BaseLauncher": "task.launchers.base",
224
240
  "CliLauncher": "task.launchers.cli",
225
241
  "SingleProcessLauncher": "task.launchers.single_process",
242
+ "LogDistribution": "task.logger",
243
+ "LogError": "task.logger",
244
+ "LogErrorSummary": "task.logger",
245
+ "LogGraph": "task.logger",
246
+ "LogHistogram": "task.logger",
226
247
  "LogImage": "task.logger",
227
248
  "LogLine": "task.logger",
249
+ "LogMesh": "task.logger",
250
+ "LogPing": "task.logger",
251
+ "LogScalar": "task.logger",
252
+ "LogStatus": "task.logger",
253
+ "LogVideo": "task.logger",
228
254
  "Logger": "task.logger",
229
255
  "LoggerImpl": "task.logger",
230
256
  "CallbackLogger": "task.loggers.callback",
@@ -238,7 +264,6 @@ NAME_MAP: dict[str, str] = {
238
264
  "GPUStatsOptions": "task.mixins.gpu_stats",
239
265
  "StepContext": "task.mixins.step_wrapper",
240
266
  "ValidStepTimer": "task.mixins.train",
241
- "get_param_count": "task.mixins.train",
242
267
  "Script": "task.script",
243
268
  "ScriptConfig": "task.script",
244
269
  "Config": "task.task",
@@ -283,6 +308,7 @@ NAME_MAP: dict[str, str] = {
283
308
  "compute_nan_ratio": "utils.pytree",
284
309
  "flatten_array": "utils.pytree",
285
310
  "flatten_pytree": "utils.pytree",
311
+ "get_param_count": "utils.pytree",
286
312
  "pytree_has_nans": "utils.pytree",
287
313
  "reshuffle_pytree": "utils.pytree",
288
314
  "reshuffle_pytree_along_dims": "utils.pytree",
@@ -376,9 +402,12 @@ if IMPORT_ALL or TYPE_CHECKING:
376
402
  cubic_bezier_interpolation,
377
403
  euler_to_quat,
378
404
  get_projected_gravity_vector_from_quat,
405
+ normalize,
379
406
  quat_to_euler,
380
407
  quat_to_rotmat,
381
408
  rotate_vector_by_quat,
409
+ rotation6d_to_rotation_matrix,
410
+ rotation_matrix_to_rotation6d,
382
411
  )
383
412
  from xax.nn.losses import cross_entropy
384
413
  from xax.nn.norm import NormType, cast_norm_type, get_norm
@@ -388,7 +417,22 @@ if IMPORT_ALL or TYPE_CHECKING:
388
417
  from xax.task.launchers.base import BaseLauncher
389
418
  from xax.task.launchers.cli import CliLauncher
390
419
  from xax.task.launchers.single_process import SingleProcessLauncher
391
- from xax.task.logger import Logger, LoggerImpl, LogImage, LogLine
420
+ from xax.task.logger import (
421
+ LogDistribution,
422
+ LogError,
423
+ LogErrorSummary,
424
+ Logger,
425
+ LoggerImpl,
426
+ LogGraph,
427
+ LogHistogram,
428
+ LogImage,
429
+ LogLine,
430
+ LogMesh,
431
+ LogPing,
432
+ LogScalar,
433
+ LogStatus,
434
+ LogVideo,
435
+ )
392
436
  from xax.task.loggers.callback import CallbackLogger
393
437
  from xax.task.loggers.json import JsonLogger
394
438
  from xax.task.loggers.state import StateLogger
@@ -399,7 +443,7 @@ if IMPORT_ALL or TYPE_CHECKING:
399
443
  from xax.task.mixins.data_loader import DataloaderConfig
400
444
  from xax.task.mixins.gpu_stats import GPUStatsOptions
401
445
  from xax.task.mixins.step_wrapper import StepContext
402
- from xax.task.mixins.train import Batch, Output, ValidStepTimer, get_param_count
446
+ from xax.task.mixins.train import Batch, Output, ValidStepTimer
403
447
  from xax.task.script import Script, ScriptConfig
404
448
  from xax.task.task import Config, Task
405
449
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
@@ -444,6 +488,7 @@ if IMPORT_ALL or TYPE_CHECKING:
444
488
  compute_nan_ratio,
445
489
  flatten_array,
446
490
  flatten_pytree,
491
+ get_pytree_param_count,
447
492
  pytree_has_nans,
448
493
  reshuffle_pytree,
449
494
  reshuffle_pytree_along_dims,
@@ -26,7 +26,7 @@ def field(value: FieldType, **kwargs: str) -> FieldType:
26
26
  metadata: dict[str, Any] = {}
27
27
  metadata.update(kwargs)
28
28
 
29
- if hasattr(value, "__call__"):
29
+ if hasattr(value, "__call__"): # noqa: B004
30
30
  return field_base(default_factory=value, metadata=metadata)
31
31
  if value.__class__.__hash__ is None:
32
32
  return field_base(default_factory=lambda: value, metadata=metadata)
@@ -68,8 +68,8 @@ def _infer_activation(activation: ActivationFunction) -> Callable:
68
68
  return lambda x: x
69
69
  try:
70
70
  return getattr(jax.nn, activation)
71
- except AttributeError:
72
- raise ValueError(f"Activation function `{activation}` not found in `jax.nn`")
71
+ except AttributeError as err:
72
+ raise ValueError(f"Activation function `{activation}` not found in `jax.nn`") from err
73
73
 
74
74
 
75
75
  def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.MLP:
@@ -100,7 +100,7 @@ def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.ML
100
100
  def export_eqx_mlp(
101
101
  model: eqx.nn.MLP,
102
102
  output_path: str | Path,
103
- dtype: jax.numpy.dtype = eqx._misc.default_floating_dtype(),
103
+ dtype: jax.numpy.dtype | None = None,
104
104
  ) -> None:
105
105
  """Serialize an Equinox MLP to a .eqx file.
106
106
 
@@ -109,6 +109,9 @@ def export_eqx_mlp(
109
109
  output_path: The path to save the exported model.
110
110
  dtype: The dtype of the model.
111
111
  """
112
+ if dtype is None:
113
+ dtype = eqx._misc.default_floating_dtype()
114
+
112
115
  activation = model.activation.__name__
113
116
  final_activation = model.final_activation.__name__
114
117
 
@@ -58,13 +58,16 @@ def recursive_chunk(item: Any, num_chunks: int, dim: int = 0) -> Iterable[Any]:
58
58
  yield from np.array_split(item, num_chunks, axis=dim)
59
59
  elif is_dataclass(item):
60
60
  yield from (
61
- item.__class__(**{k: i for k, i in zip(item.__dict__, ii)})
62
- for ii in zip(*(recursive_chunk(v, num_chunks, dim) for v in item.__dict__.values()))
61
+ item.__class__(**{k: i for k, i in zip(item.__dict__, ii, strict=True)})
62
+ for ii in zip(*(recursive_chunk(v, num_chunks, dim) for v in item.__dict__.values()), strict=False)
63
63
  )
64
64
  elif isinstance(item, Mapping):
65
- yield from (dict(zip(item, ii)) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item.values())))
65
+ yield from (
66
+ dict(zip(item, ii, strict=False))
67
+ for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item.values()), strict=False)
68
+ )
66
69
  elif isinstance(item, Sequence):
67
- yield from (list(ii) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item)))
70
+ yield from (list(ii) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item), strict=False))
68
71
  else:
69
72
  yield from (item for _ in range(num_chunks))
70
73
 
@@ -1,5 +1,6 @@
1
1
  """Defines geometry functions."""
2
2
 
3
+ import chex
3
4
  from jax import numpy as jnp
4
5
  from jaxtyping import Array
5
6
 
@@ -211,3 +212,51 @@ def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
211
212
  ],
212
213
  axis=-2,
213
214
  )
215
+
216
+
217
+ def normalize(v: jnp.ndarray, axis: int = -1, eps: float = 1e-8) -> jnp.ndarray:
218
+ norm = jnp.linalg.norm(v, axis=axis, keepdims=True)
219
+ return v / jnp.clip(norm, a_min=eps)
220
+
221
+
222
+ def rotation6d_to_rotation_matrix(r6d: jnp.ndarray) -> jnp.ndarray:
223
+ """Convert 6D rotation representation to rotation matrix.
224
+
225
+ From https://arxiv.org/pdf/1812.07035, Appendix B
226
+
227
+ Args:
228
+ r6d: The 6D rotation representation, shape (*, 6).
229
+
230
+ Returns:
231
+ The rotation matrix, shape (*, 3, 3).
232
+ """
233
+ chex.assert_shape(r6d, (..., 6))
234
+ shape = r6d.shape
235
+ flat = r6d.reshape(-1, 6)
236
+ a_1 = flat[:, 0:3]
237
+ a_2 = flat[:, 3:6]
238
+
239
+ b_1 = normalize(a_1, axis=-1)
240
+
241
+ # Reordered Gram-Schmidt orthonormalization.
242
+ b_3 = normalize(jnp.cross(b_1, a_2), axis=-1)
243
+ b_2 = jnp.cross(b_3, b_1)
244
+
245
+ rotation_matrix = jnp.stack([b_1, b_2, b_3], axis=-1)
246
+ return rotation_matrix.reshape(shape[:-1] + (3, 3))
247
+
248
+
249
+ def rotation_matrix_to_rotation6d(rotation_matrix: jnp.ndarray) -> jnp.ndarray:
250
+ """Convert rotation matrix to 6D rotation representation.
251
+
252
+ Args:
253
+ rotation_matrix: The rotation matrix, shape (*, 3, 3).
254
+
255
+ Returns:
256
+ The 6D rotation representation, shape (*, 6).
257
+ """
258
+ chex.assert_shape(rotation_matrix, (..., 3, 3))
259
+ shape = rotation_matrix.shape
260
+ # Simply concatenate a1 and a2 from SO(3)
261
+ r6d = jnp.concatenate([rotation_matrix[..., 0], rotation_matrix[..., 1]], axis=-1)
262
+ return r6d.reshape(shape[:-2] + (6,))
@@ -184,8 +184,8 @@ class BaseTask(Generic[Config]):
184
184
 
185
185
  # Attempts to load any paths as configs.
186
186
  is_path = [Path(arg).is_file() or (task_path / arg).is_file() for arg in args]
187
- paths = [arg for arg, is_path in zip(args, is_path) if is_path]
188
- non_paths = [arg for arg, is_path in zip(args, is_path) if not is_path]
187
+ paths = [arg for arg, is_path in zip(args, is_path, strict=True) if is_path]
188
+ non_paths = [arg for arg, is_path in zip(args, is_path, strict=True) if not is_path]
189
189
  if paths:
190
190
  cfg = OmegaConf.merge(cfg, *(get_config(path, task_path) for path in paths))
191
191
  cfg = OmegaConf.merge(cfg, OmegaConf.from_cli(non_paths))
@@ -462,11 +462,11 @@ class LoggerImpl(ABC):
462
462
 
463
463
  self.tickers = {phase: IntervalTicker(log_interval_seconds) for phase in get_args(Phase)}
464
464
 
465
- def start(self) -> None:
466
- pass
465
+ @abstractmethod
466
+ def start(self) -> None: ...
467
467
 
468
- def stop(self) -> None:
469
- pass
468
+ @abstractmethod
469
+ def stop(self) -> None: ...
470
470
 
471
471
  @abstractmethod
472
472
  def write(self, line: LogLine) -> None:
@@ -476,6 +476,7 @@ class LoggerImpl(ABC):
476
476
  line: The line to write.
477
477
  """
478
478
 
479
+ @abstractmethod
479
480
  def write_error_summary(self, error_summary: LogErrorSummary) -> None:
480
481
  """Handles writing an error summary.
481
482
 
@@ -483,6 +484,7 @@ class LoggerImpl(ABC):
483
484
  error_summary: The error summary to write.
484
485
  """
485
486
 
487
+ @abstractmethod
486
488
  def write_error(self, error: LogError) -> None:
487
489
  """Handles writing an error line.
488
490
 
@@ -490,6 +492,7 @@ class LoggerImpl(ABC):
490
492
  error: The error information to write.
491
493
  """
492
494
 
495
+ @abstractmethod
493
496
  def write_status(self, status: LogStatus) -> None:
494
497
  """Handles writing a status line.
495
498
 
@@ -497,6 +500,7 @@ class LoggerImpl(ABC):
497
500
  status: The status to write.
498
501
  """
499
502
 
503
+ @abstractmethod
500
504
  def write_ping(self, ping: LogPing) -> None:
501
505
  """Handles writing a ping line.
502
506
 
@@ -504,6 +508,7 @@ class LoggerImpl(ABC):
504
508
  ping: The ping to write.
505
509
  """
506
510
 
511
+ @abstractmethod
507
512
  def log_file(self, name: str, contents: str) -> None:
508
513
  """Logs a large text file.
509
514
 
@@ -621,7 +626,7 @@ class Logger:
621
626
  return
622
627
  line = self.pack(state)
623
628
  self.clear()
624
- for lg in (lg for lg, should_log in zip(self.loggers, should_log) if should_log):
629
+ for lg in (lg for lg, should_log in zip(self.loggers, should_log, strict=False) if should_log):
625
630
  lg.write(line)
626
631
 
627
632
  def write_error_summary(self, error_summary: str) -> None:
@@ -1045,7 +1050,7 @@ class Logger:
1045
1050
  line_spacing=line_spacing,
1046
1051
  centered=centered,
1047
1052
  )
1048
- for img, label in zip(images, labels)
1053
+ for img, label in zip(images, labels, strict=True)
1049
1054
  ]
1050
1055
  tiled = tile_images([img.image for img in labeled], sep)
1051
1056
 
@@ -25,6 +25,12 @@ class CallbackLogger(LoggerImpl):
25
25
  self.ping_callback = ping_callback
26
26
  self.file_callback = file_callback
27
27
 
28
+ def start(self) -> None:
29
+ pass
30
+
31
+ def stop(self) -> None:
32
+ pass
33
+
28
34
  def write(self, line: LogLine) -> None:
29
35
  self.callback(line)
30
36
 
@@ -8,6 +8,7 @@ from jaxtyping import Array
8
8
 
9
9
  from xax.task.logger import (
10
10
  LogError,
11
+ LogErrorSummary,
11
12
  LoggerImpl,
12
13
  LogLine,
13
14
  LogPing,
@@ -57,6 +58,12 @@ class JsonLogger(LoggerImpl):
57
58
  self.line_sep = line_sep
58
59
  self.remove_unicode_from_namespaces = remove_unicode_from_namespaces
59
60
 
61
+ def start(self) -> None:
62
+ pass
63
+
64
+ def stop(self) -> None:
65
+ pass
66
+
60
67
  @property
61
68
  def fp(self) -> TextIO:
62
69
  return self.log_stream
@@ -87,6 +94,12 @@ class JsonLogger(LoggerImpl):
87
94
  if self.flush_immediately:
88
95
  self.fp.flush()
89
96
 
97
+ def write_error_summary(self, error_summary: LogErrorSummary) -> None:
98
+ pass
99
+
100
+ def log_file(self, name: str, contents: str) -> None:
101
+ pass
102
+
90
103
  def write_error(self, error: LogError) -> None:
91
104
  self.err_fp.write(error.message)
92
105
  if error.location is not None:
@@ -3,7 +3,14 @@
3
3
  from pathlib import Path
4
4
  from typing import Literal
5
5
 
6
- from xax.task.logger import LoggerImpl, LogLine
6
+ from xax.task.logger import (
7
+ LogError,
8
+ LogErrorSummary,
9
+ LoggerImpl,
10
+ LogLine,
11
+ LogPing,
12
+ LogStatus,
13
+ )
7
14
 
8
15
 
9
16
  class StateLogger(LoggerImpl):
@@ -30,3 +37,21 @@ class StateLogger(LoggerImpl):
30
37
 
31
38
  def write(self, line: LogLine) -> None:
32
39
  pass
40
+
41
+ def start(self) -> None:
42
+ pass
43
+
44
+ def stop(self) -> None:
45
+ pass
46
+
47
+ def write_error_summary(self, error_summary: LogErrorSummary) -> None:
48
+ pass
49
+
50
+ def write_error(self, error: LogError) -> None:
51
+ pass
52
+
53
+ def write_status(self, status: LogStatus) -> None:
54
+ pass
55
+
56
+ def write_ping(self, ping: LogPing) -> None:
57
+ pass
@@ -79,11 +79,13 @@ class StdoutLogger(LoggerImpl):
79
79
  self.error_summary: tuple[str, datetime.datetime] | None = None
80
80
 
81
81
  def start(self) -> None:
82
- return super().start()
82
+ pass
83
83
 
84
84
  def stop(self) -> None:
85
85
  self.write_queues()
86
- return super().stop()
86
+
87
+ def log_file(self, name: str, contents: str) -> None:
88
+ pass
87
89
 
88
90
  def write_separator(self) -> None:
89
91
  self.write_fp.write("\033[2J\033[H")
@@ -12,7 +12,7 @@ from typing import TypeVar
12
12
 
13
13
  from xax.core.state import Phase
14
14
  from xax.nn.parallel import is_master
15
- from xax.task.logger import LoggerImpl, LogLine
15
+ from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
16
16
  from xax.utils.jax import as_float
17
17
  from xax.utils.logging import LOG_STATUS, port_is_busy
18
18
  from xax.utils.tensorboard import TensorboardWriter, TensorboardWriters
@@ -236,3 +236,21 @@ class TensorboardLogger(LoggerImpl):
236
236
  for name, contents in self.files.items():
237
237
  writer.add_text(name, contents)
238
238
  self.files.clear()
239
+
240
+ def start(self) -> None:
241
+ pass
242
+
243
+ def stop(self) -> None:
244
+ pass
245
+
246
+ def write_error(self, error: LogError) -> None:
247
+ pass
248
+
249
+ def write_error_summary(self, error_summary: LogErrorSummary) -> None:
250
+ pass
251
+
252
+ def write_ping(self, ping: LogPing) -> None:
253
+ pass
254
+
255
+ def write_status(self, status: LogStatus) -> None:
256
+ pass
@@ -31,11 +31,13 @@ Config = TypeVar("Config", bound=ArtifactsConfig)
31
31
 
32
32
  class ArtifactsMixin(BaseTask[Config]):
33
33
  _exp_dir: Path | None
34
+ _stage_dir: Path | None
34
35
 
35
36
  def __init__(self, config: Config) -> None:
36
37
  super().__init__(config)
37
38
 
38
39
  self._exp_dir = None
40
+ self._stage_dir = None
39
41
 
40
42
  @functools.cached_property
41
43
  def run_dir(self) -> Path:
@@ -75,15 +77,16 @@ class ArtifactsMixin(BaseTask[Config]):
75
77
  logger.log(LOG_STATUS, self._exp_dir)
76
78
  return self._exp_dir
77
79
 
78
- @functools.lru_cache(maxsize=None)
79
80
  def stage_environment(self) -> Path | None:
80
- stage_dir = (self.exp_dir / "code").resolve()
81
- try:
82
- stage_environment(self, stage_dir)
83
- except Exception:
84
- logger.exception("Failed to stage environment!")
85
- return None
86
- return stage_dir
81
+ if self._stage_dir is None:
82
+ stage_dir = (self.exp_dir / "code").resolve()
83
+ try:
84
+ stage_environment(self, stage_dir)
85
+ except Exception:
86
+ logger.exception("Failed to stage environment!")
87
+ return None
88
+ self._stage_dir = stage_dir
89
+ return self._stage_dir
87
90
 
88
91
  def on_training_end(self, state: State) -> State:
89
92
  state = super().on_training_end(state)
@@ -57,6 +57,7 @@ from xax.utils.experiments import (
57
57
  )
58
58
  from xax.utils.jax import jit as xax_jit
59
59
  from xax.utils.logging import LOG_PING, LOG_STATUS
60
+ from xax.utils.pytree import get_pytree_param_count
60
61
  from xax.utils.text import highlight_exception_message, show_info
61
62
  from xax.utils.types.frozen_dict import FrozenDict
62
63
 
@@ -96,12 +97,6 @@ def batches_per_step_schedule(schedule: list[int] | None) -> list[int] | None:
96
97
  return list(itertools.accumulate([0] + schedule))
97
98
 
98
99
 
99
- def get_param_count(pytree: PyTree) -> int:
100
- """Calculates the total number of parameters in a PyTree."""
101
- leaves, _ = jax.tree.flatten(pytree)
102
- return sum(x.size for x in leaves if isinstance(x, jnp.ndarray))
103
-
104
-
105
100
  class ValidStepTimer:
106
101
  def __init__(
107
102
  self,
@@ -690,7 +685,7 @@ class TrainMixin(
690
685
  self.logger.log_file("info.json", get_info_json())
691
686
 
692
687
  def log_model_size(self, model: PyTree) -> None:
693
- logger.info("Model size: %s", f"{get_param_count(model):,}")
688
+ logger.info("Model size: %s", f"{get_pytree_param_count(model):,}")
694
689
 
695
690
  def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
696
691
  return eqx.is_inexact_array(item)
@@ -749,7 +749,8 @@ class BaseFileDownloader(ABC):
749
749
  f"We detected some HTML elements in the downloaded file. "
750
750
  f"This most likely means that the download triggered an unhandled API response by GDrive. "
751
751
  f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
752
- f"the response:\n\n{text}"
752
+ f"the response:\n\n{text}",
753
+ stacklevel=2,
753
754
  )
754
755
 
755
756
  @classmethod
@@ -1,6 +1,7 @@
1
1
  """Utils for accessing, modifying, and otherwise manipulating pytrees."""
2
2
 
3
3
  import chex
4
+ import equinox as eqx
4
5
  import jax
5
6
  import jax.numpy as jnp
6
7
  from jax import Array
@@ -124,7 +125,7 @@ def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArr
124
125
  def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
125
126
  """Reshuffle a rollout array across arbitrary batch dimensions independently of each other."""
126
127
  rngs = jax.random.split(rng, len(batch_shape))
127
- perms = [jax.random.permutation(rng_i, dim) for rng_i, dim in zip(rngs, batch_shape)]
128
+ perms = [jax.random.permutation(rng_i, dim) for rng_i, dim in zip(rngs, batch_shape, strict=True)]
128
129
  # n-dimensional index grid from permutations
129
130
  idx_grids = jnp.meshgrid(*perms, indexing="ij")
130
131
 
@@ -236,3 +237,9 @@ def reshuffle_pytree_along_dims(
236
237
  return x
237
238
 
238
239
  return jax.tree.map_with_path(restore_transpose, reshuffled_transposed)
240
+
241
+
242
+ def get_pytree_param_count(pytree: PyTree) -> int:
243
+ """Calculates the total number of parameters in a PyTree."""
244
+ leaves, _ = jax.tree.flatten(pytree)
245
+ return sum(x.size for x in leaves if isinstance(x, jnp.ndarray) and eqx.is_inexact_array(x))
@@ -192,7 +192,7 @@ def render_text_blocks(
192
192
  if any(len(row) != len(blocks[0]) for row in blocks):
193
193
  raise ValueError("All rows must have the same number of blocks in order to align them")
194
194
  widths = [[max(len(line) for line in i.lines) if i.width is None else i.width for i in r] for r in blocks]
195
- row_widths = [max(i) for i in zip(*widths)]
195
+ row_widths = [max(i) for i in zip(*widths, strict=True)]
196
196
  for row in blocks:
197
197
  for i, block in enumerate(row):
198
198
  block.width = row_widths[i]
@@ -263,7 +263,7 @@ def render_text_blocks(
263
263
  if i >= len(block.lines)
264
264
  else colored(pad(block.lines[i], width, block.center), block.color, bold=block.bold)
265
265
  )
266
- for block, width in zip(row, get_widths(row))
266
+ for block, width in zip(row, get_widths(row), strict=True)
267
267
  ]
268
268
  )
269
269
  + " │"
@@ -133,7 +133,7 @@ class FrozenDict(Mapping[K, V]):
133
133
 
134
134
  @classmethod
135
135
  def tree_unflatten(cls, keys: tuple[K, ...], values: tuple[Any, ...]) -> "FrozenDict[K, V]":
136
- return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True)
136
+ return cls({k: v for k, v in zip(keys, values, strict=True)}, __unsafe_skip_copy__=True)
137
137
 
138
138
 
139
139
  def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]: # noqa: ANN401
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.7
3
+ Version: 0.2.8
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes