xax 0.3.1__tar.gz → 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. {xax-0.3.1/xax.egg-info → xax-0.3.3}/PKG-INFO +1 -1
  2. {xax-0.3.1 → xax-0.3.3}/xax/__init__.py +6 -3
  3. {xax-0.3.1 → xax-0.3.3}/xax/core/state.py +1 -1
  4. {xax-0.3.1 → xax-0.3.3}/xax/nn/geom.py +82 -0
  5. {xax-0.3.1 → xax-0.3.3}/xax/task/base.py +3 -0
  6. {xax-0.3.1 → xax-0.3.3}/xax/task/launchers/single_process.py +2 -1
  7. {xax-0.3.1 → xax-0.3.3}/xax/task/mixins/artifacts.py +10 -1
  8. {xax-0.3.1 → xax-0.3.3}/xax/task/mixins/runnable.py +1 -2
  9. {xax-0.3.1 → xax-0.3.3}/xax/utils/logging.py +3 -1
  10. {xax-0.3.1 → xax-0.3.3/xax.egg-info}/PKG-INFO +1 -1
  11. {xax-0.3.1 → xax-0.3.3}/LICENSE +0 -0
  12. {xax-0.3.1 → xax-0.3.3}/MANIFEST.in +0 -0
  13. {xax-0.3.1 → xax-0.3.3}/README.md +0 -0
  14. {xax-0.3.1 → xax-0.3.3}/pyproject.toml +0 -0
  15. {xax-0.3.1 → xax-0.3.3}/setup.cfg +0 -0
  16. {xax-0.3.1 → xax-0.3.3}/setup.py +0 -0
  17. {xax-0.3.1 → xax-0.3.3}/xax/cli/__init__.py +0 -0
  18. {xax-0.3.1 → xax-0.3.3}/xax/cli/edit_config.py +0 -0
  19. {xax-0.3.1 → xax-0.3.3}/xax/core/__init__.py +0 -0
  20. {xax-0.3.1 → xax-0.3.3}/xax/core/conf.py +0 -0
  21. {xax-0.3.1 → xax-0.3.3}/xax/nn/__init__.py +0 -0
  22. {xax-0.3.1 → xax-0.3.3}/xax/nn/attention.py +0 -0
  23. {xax-0.3.1 → xax-0.3.3}/xax/nn/embeddings.py +0 -0
  24. {xax-0.3.1 → xax-0.3.3}/xax/nn/functions.py +0 -0
  25. {xax-0.3.1 → xax-0.3.3}/xax/nn/losses.py +0 -0
  26. {xax-0.3.1 → xax-0.3.3}/xax/nn/metrics.py +0 -0
  27. {xax-0.3.1 → xax-0.3.3}/xax/nn/parallel.py +0 -0
  28. {xax-0.3.1 → xax-0.3.3}/xax/nn/ssm.py +0 -0
  29. {xax-0.3.1 → xax-0.3.3}/xax/py.typed +0 -0
  30. {xax-0.3.1 → xax-0.3.3}/xax/requirements-dev.txt +0 -0
  31. {xax-0.3.1 → xax-0.3.3}/xax/requirements.txt +0 -0
  32. {xax-0.3.1 → xax-0.3.3}/xax/task/__init__.py +0 -0
  33. {xax-0.3.1 → xax-0.3.3}/xax/task/launchers/__init__.py +0 -0
  34. {xax-0.3.1 → xax-0.3.3}/xax/task/launchers/base.py +0 -0
  35. {xax-0.3.1 → xax-0.3.3}/xax/task/launchers/cli.py +0 -0
  36. {xax-0.3.1 → xax-0.3.3}/xax/task/logger.py +0 -0
  37. {xax-0.3.1 → xax-0.3.3}/xax/task/loggers/__init__.py +0 -0
  38. {xax-0.3.1 → xax-0.3.3}/xax/task/loggers/callback.py +0 -0
  39. {xax-0.3.1 → xax-0.3.3}/xax/task/loggers/json.py +0 -0
  40. {xax-0.3.1 → xax-0.3.3}/xax/task/loggers/state.py +0 -0
  41. {xax-0.3.1 → xax-0.3.3}/xax/task/loggers/stdout.py +0 -0
  42. {xax-0.3.1 → xax-0.3.3}/xax/task/loggers/tensorboard.py +0 -0
  43. {xax-0.3.1 → xax-0.3.3}/xax/task/mixins/__init__.py +0 -0
  44. {xax-0.3.1 → xax-0.3.3}/xax/task/mixins/checkpointing.py +0 -0
  45. {xax-0.3.1 → xax-0.3.3}/xax/task/mixins/compile.py +0 -0
  46. {xax-0.3.1 → xax-0.3.3}/xax/task/mixins/cpu_stats.py +0 -0
  47. {xax-0.3.1 → xax-0.3.3}/xax/task/mixins/data_loader.py +0 -0
  48. {xax-0.3.1 → xax-0.3.3}/xax/task/mixins/gpu_stats.py +0 -0
  49. {xax-0.3.1 → xax-0.3.3}/xax/task/mixins/logger.py +0 -0
  50. {xax-0.3.1 → xax-0.3.3}/xax/task/mixins/process.py +0 -0
  51. {xax-0.3.1 → xax-0.3.3}/xax/task/mixins/step_wrapper.py +0 -0
  52. {xax-0.3.1 → xax-0.3.3}/xax/task/mixins/train.py +0 -0
  53. {xax-0.3.1 → xax-0.3.3}/xax/task/script.py +0 -0
  54. {xax-0.3.1 → xax-0.3.3}/xax/task/task.py +0 -0
  55. {xax-0.3.1 → xax-0.3.3}/xax/utils/__init__.py +0 -0
  56. {xax-0.3.1 → xax-0.3.3}/xax/utils/data/__init__.py +0 -0
  57. {xax-0.3.1 → xax-0.3.3}/xax/utils/data/collate.py +0 -0
  58. {xax-0.3.1 → xax-0.3.3}/xax/utils/debugging.py +0 -0
  59. {xax-0.3.1 → xax-0.3.3}/xax/utils/experiments.py +0 -0
  60. {xax-0.3.1 → xax-0.3.3}/xax/utils/jax.py +0 -0
  61. {xax-0.3.1 → xax-0.3.3}/xax/utils/jaxpr.py +0 -0
  62. {xax-0.3.1 → xax-0.3.3}/xax/utils/numpy.py +0 -0
  63. {xax-0.3.1 → xax-0.3.3}/xax/utils/profile.py +0 -0
  64. {xax-0.3.1 → xax-0.3.3}/xax/utils/pytree.py +0 -0
  65. {xax-0.3.1 → xax-0.3.3}/xax/utils/tensorboard.py +0 -0
  66. {xax-0.3.1 → xax-0.3.3}/xax/utils/text.py +0 -0
  67. {xax-0.3.1 → xax-0.3.3}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.3.1 → xax-0.3.3}/xax/utils/types/frozen_dict.py +0 -0
  69. {xax-0.3.1 → xax-0.3.3}/xax/utils/types/hashable_array.py +0 -0
  70. {xax-0.3.1 → xax-0.3.3}/xax.egg-info/SOURCES.txt +0 -0
  71. {xax-0.3.1 → xax-0.3.3}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.3.1 → xax-0.3.3}/xax.egg-info/entry_points.txt +0 -0
  73. {xax-0.3.1 → xax-0.3.3}/xax.egg-info/requires.txt +0 -0
  74. {xax-0.3.1 → xax-0.3.3}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.1"
15
+ __version__ = "0.3.3"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -42,11 +42,12 @@ __all__ = [
42
42
  "euler_to_quat",
43
43
  "get_projected_gravity_vector_from_quat",
44
44
  "normalize",
45
+ "quat_mul",
45
46
  "quat_to_euler",
46
47
  "quat_to_rotmat",
47
- "quat_mul",
48
48
  "rotate_vector_by_quat",
49
49
  "rotation6d_to_rotation_matrix",
50
+ "rotation_matrix_to_quat",
50
51
  "rotation_matrix_to_rotation6d",
51
52
  "cross_entropy",
52
53
  "cast_norm_type",
@@ -224,11 +225,12 @@ NAME_MAP: dict[str, str] = {
224
225
  "euler_to_quat": "nn.geom",
225
226
  "get_projected_gravity_vector_from_quat": "nn.geom",
226
227
  "normalize": "nn.geom",
228
+ "quat_mul": "nn.geom",
227
229
  "quat_to_euler": "nn.geom",
228
230
  "quat_to_rotmat": "nn.geom",
229
- "quat_mul": "nn.geom",
230
231
  "rotate_vector_by_quat": "nn.geom",
231
232
  "rotation6d_to_rotation_matrix": "nn.geom",
233
+ "rotation_matrix_to_quat": "nn.geom",
232
234
  "rotation_matrix_to_rotation6d": "nn.geom",
233
235
  "cross_entropy": "nn.losses",
234
236
  "cast_norm_type": "nn.metrics",
@@ -405,6 +407,7 @@ if IMPORT_ALL or TYPE_CHECKING:
405
407
  quat_to_rotmat,
406
408
  rotate_vector_by_quat,
407
409
  rotation6d_to_rotation_matrix,
410
+ rotation_matrix_to_quat,
408
411
  rotation_matrix_to_rotation6d,
409
412
  )
410
413
  from xax.nn.losses import cross_entropy
@@ -107,7 +107,7 @@ class State:
107
107
  @classmethod
108
108
  def from_dict(cls, **d: Unpack[StateDict]) -> "State":
109
109
  if "phase" in d:
110
- d["_phase"] = _phase_to_int(cast(Phase, d.pop("phase")))
110
+ d["_phase"] = _phase_to_int(d.pop("phase"))
111
111
 
112
112
  int32_arr = jnp.array(
113
113
  [
@@ -272,3 +272,85 @@ def quat_mul(q2: Array, q1: Array) -> Array:
272
272
  z = w2 * z1 + x2 * y1 - y2 * x1 + z2 * w1
273
273
 
274
274
  return jnp.concatenate([w, x, y, z], axis=-1)
275
+
276
+
277
+ def rotation_matrix_to_quat(rotation_matrix: Array, eps: float = 1e-6) -> Array:
278
+ """Converts a rotation matrix to a unit quaternion ``(w, x, y, z)``.
279
+
280
+ Args:
281
+ rotation_matrix: The rotation matrix, shape ``(*, 3, 3)``.
282
+ eps: A small epsilon value to avoid division by zero when normalising.
283
+
284
+ Returns:
285
+ A quaternion with shape ``(*, 4)``.
286
+ """
287
+ chex.assert_shape(rotation_matrix, (..., 3, 3))
288
+
289
+ m00 = rotation_matrix[..., 0, 0]
290
+ m01 = rotation_matrix[..., 0, 1]
291
+ m02 = rotation_matrix[..., 0, 2]
292
+ m10 = rotation_matrix[..., 1, 0]
293
+ m11 = rotation_matrix[..., 1, 1]
294
+ m12 = rotation_matrix[..., 1, 2]
295
+ m20 = rotation_matrix[..., 2, 0]
296
+ m21 = rotation_matrix[..., 2, 1]
297
+ m22 = rotation_matrix[..., 2, 2]
298
+
299
+ trace = m00 + m11 + m22
300
+
301
+ # Case 0: trace is positive
302
+ s0 = jnp.sqrt(jnp.clip(trace + 1.0, a_min=0.0)) * 2.0 # S = 4 * qw
303
+ w0 = 0.25 * s0
304
+ x0 = (m21 - m12) / jnp.where(s0 < eps, 1.0, s0)
305
+ y0 = (m02 - m20) / jnp.where(s0 < eps, 1.0, s0)
306
+ z0 = (m10 - m01) / jnp.where(s0 < eps, 1.0, s0)
307
+
308
+ # Case 1: m00 is the largest diagonal term
309
+ s1 = jnp.sqrt(jnp.clip(1.0 + m00 - m11 - m22, a_min=0.0)) * 2.0 # S = 4 * qx
310
+ w1 = (m21 - m12) / jnp.where(s1 < eps, 1.0, s1)
311
+ x1 = 0.25 * s1
312
+ y1 = (m01 + m10) / jnp.where(s1 < eps, 1.0, s1)
313
+ z1 = (m02 + m20) / jnp.where(s1 < eps, 1.0, s1)
314
+
315
+ # Case 2: m11 is the largest diagonal term
316
+ s2 = jnp.sqrt(jnp.clip(1.0 + m11 - m00 - m22, a_min=0.0)) * 2.0 # S = 4 * qy
317
+ w2 = (m02 - m20) / jnp.where(s2 < eps, 1.0, s2)
318
+ x2 = (m01 + m10) / jnp.where(s2 < eps, 1.0, s2)
319
+ y2 = 0.25 * s2
320
+ z2 = (m12 + m21) / jnp.where(s2 < eps, 1.0, s2)
321
+
322
+ # Case 3: m22 is the largest diagonal term
323
+ s3 = jnp.sqrt(jnp.clip(1.0 + m22 - m00 - m11, a_min=0.0)) * 2.0 # S = 4 * qz
324
+ w3 = (m10 - m01) / jnp.where(s3 < eps, 1.0, s3)
325
+ x3 = (m02 + m20) / jnp.where(s3 < eps, 1.0, s3)
326
+ y3 = (m12 + m21) / jnp.where(s3 < eps, 1.0, s3)
327
+ z3 = 0.25 * s3
328
+
329
+ cond0 = trace > 0.0
330
+ cond1 = (m00 > m11) & (m00 > m22)
331
+ cond2 = m11 > m22
332
+
333
+ w = jnp.where(
334
+ cond0,
335
+ w0,
336
+ jnp.where(cond1, w1, jnp.where(cond2, w2, w3)),
337
+ )
338
+ x = jnp.where(
339
+ cond0,
340
+ x0,
341
+ jnp.where(cond1, x1, jnp.where(cond2, x2, x3)),
342
+ )
343
+ y = jnp.where(
344
+ cond0,
345
+ y0,
346
+ jnp.where(cond1, y1, jnp.where(cond2, y2, y3)),
347
+ )
348
+ z = jnp.where(
349
+ cond0,
350
+ z0,
351
+ jnp.where(cond1, z1, jnp.where(cond2, z2, z3)),
352
+ )
353
+
354
+ quat = jnp.stack([w, x, y, z], axis=-1)
355
+ quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
356
+ return quat
@@ -82,6 +82,9 @@ class BaseTask(Generic[Config]):
82
82
  def on_after_checkpoint_save(self, ckpt_path: Path, state: State | None) -> State | None:
83
83
  return state
84
84
 
85
+ def add_logger_handlers(self, logger: logging.Logger) -> None:
86
+ pass
87
+
85
88
  @functools.cached_property
86
89
  def task_class_name(self) -> str:
87
90
  return self.__class__.__name__
@@ -15,8 +15,9 @@ def run_single_process_training(
15
15
  *cfgs: RawConfigType,
16
16
  use_cli: bool | list[str] = True,
17
17
  ) -> None:
18
- configure_logging()
18
+ logger = configure_logging()
19
19
  task_obj = task.get_task(*cfgs, use_cli=use_cli)
20
+ task_obj.add_logger_handlers(logger)
20
21
  task_obj.run()
21
22
 
22
23
 
@@ -14,7 +14,7 @@ from xax.core.state import State
14
14
  from xax.nn.parallel import is_master
15
15
  from xax.task.base import BaseConfig, BaseTask
16
16
  from xax.utils.experiments import stage_environment
17
- from xax.utils.logging import LOG_STATUS
17
+ from xax.utils.logging import LOG_STATUS, RankFilter
18
18
  from xax.utils.text import show_info
19
19
 
20
20
  logger = logging.getLogger(__name__)
@@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
24
24
  @dataclass
25
25
  class ArtifactsConfig(BaseConfig):
26
26
  exp_dir: str | None = field(None, help="The fixed experiment directory")
27
+ log_to_file: bool = field(True, help="If set, add a file handler to the logger to write all logs to the exp dir")
27
28
 
28
29
 
29
30
  Config = TypeVar("Config", bound=ArtifactsConfig)
@@ -39,6 +40,14 @@ class ArtifactsMixin(BaseTask[Config]):
39
40
  self._exp_dir = None
40
41
  self._stage_dir = None
41
42
 
43
+ def add_logger_handlers(self, logger: logging.Logger) -> None:
44
+ super().add_logger_handlers(logger)
45
+ if is_master() and self.config.log_to_file:
46
+ file_handler = logging.FileHandler(self.exp_dir / "logs.txt")
47
+ file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
48
+ file_handler.addFilter(RankFilter(rank=0))
49
+ logger.addHandler(file_handler)
50
+
42
51
  @functools.cached_property
43
52
  def run_dir(self) -> Path:
44
53
  run_dir = get_run_dir()
@@ -10,6 +10,7 @@ import jax
10
10
 
11
11
  from xax.task.base import BaseConfig, BaseTask, RawConfigType
12
12
  from xax.task.launchers.base import BaseLauncher
13
+ from xax.task.launchers.cli import CliLauncher
13
14
 
14
15
 
15
16
  @jax.tree_util.register_dataclass
@@ -45,8 +46,6 @@ class RunnableMixin(BaseTask[Config], ABC):
45
46
  use_cli: bool | list[str] = True,
46
47
  ) -> None:
47
48
  if launcher is None:
48
- from xax.task.launchers.cli import CliLauncher
49
-
50
49
  launcher = CliLauncher()
51
50
  launcher.launch(cls, *cfgs, use_cli=use_cli)
52
51
 
@@ -146,7 +146,7 @@ def configure_logging(
146
146
  rank: int | None = None,
147
147
  world_size: int | None = None,
148
148
  debug: bool | None = None,
149
- ) -> None:
149
+ ) -> logging.Logger:
150
150
  """Instantiates logging.
151
151
 
152
152
  This captures logs and reroutes them to the Toasts module, which is
@@ -186,6 +186,8 @@ def configure_logging(
186
186
  logging.getLogger("PIL").setLevel(logging.WARNING)
187
187
  logging.getLogger("torch").setLevel(logging.WARNING)
188
188
 
189
+ return root_logger
190
+
189
191
 
190
192
  def get_unused_port(default: int | None = None) -> int:
191
193
  """Returns an unused port number on the local machine.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes