xax 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.1"
15
+ __version__ = "0.3.3"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -42,11 +42,12 @@ __all__ = [
42
42
  "euler_to_quat",
43
43
  "get_projected_gravity_vector_from_quat",
44
44
  "normalize",
45
+ "quat_mul",
45
46
  "quat_to_euler",
46
47
  "quat_to_rotmat",
47
- "quat_mul",
48
48
  "rotate_vector_by_quat",
49
49
  "rotation6d_to_rotation_matrix",
50
+ "rotation_matrix_to_quat",
50
51
  "rotation_matrix_to_rotation6d",
51
52
  "cross_entropy",
52
53
  "cast_norm_type",
@@ -224,11 +225,12 @@ NAME_MAP: dict[str, str] = {
224
225
  "euler_to_quat": "nn.geom",
225
226
  "get_projected_gravity_vector_from_quat": "nn.geom",
226
227
  "normalize": "nn.geom",
228
+ "quat_mul": "nn.geom",
227
229
  "quat_to_euler": "nn.geom",
228
230
  "quat_to_rotmat": "nn.geom",
229
- "quat_mul": "nn.geom",
230
231
  "rotate_vector_by_quat": "nn.geom",
231
232
  "rotation6d_to_rotation_matrix": "nn.geom",
233
+ "rotation_matrix_to_quat": "nn.geom",
232
234
  "rotation_matrix_to_rotation6d": "nn.geom",
233
235
  "cross_entropy": "nn.losses",
234
236
  "cast_norm_type": "nn.metrics",
@@ -405,6 +407,7 @@ if IMPORT_ALL or TYPE_CHECKING:
405
407
  quat_to_rotmat,
406
408
  rotate_vector_by_quat,
407
409
  rotation6d_to_rotation_matrix,
410
+ rotation_matrix_to_quat,
408
411
  rotation_matrix_to_rotation6d,
409
412
  )
410
413
  from xax.nn.losses import cross_entropy
xax/core/state.py CHANGED
@@ -107,7 +107,7 @@ class State:
107
107
  @classmethod
108
108
  def from_dict(cls, **d: Unpack[StateDict]) -> "State":
109
109
  if "phase" in d:
110
- d["_phase"] = _phase_to_int(cast(Phase, d.pop("phase")))
110
+ d["_phase"] = _phase_to_int(d.pop("phase"))
111
111
 
112
112
  int32_arr = jnp.array(
113
113
  [
xax/nn/geom.py CHANGED
@@ -272,3 +272,85 @@ def quat_mul(q2: Array, q1: Array) -> Array:
272
272
  z = w2 * z1 + x2 * y1 - y2 * x1 + z2 * w1
273
273
 
274
274
  return jnp.concatenate([w, x, y, z], axis=-1)
275
+
276
+
277
def rotation_matrix_to_quat(rotation_matrix: Array, eps: float = 1e-6) -> Array:
    """Converts a rotation matrix to a unit quaternion ``(w, x, y, z)``.

    All four candidate solutions are computed and the right one is chosen
    with ``jnp.where``, so the function works under ``jit`` / ``vmap``:

    - the ``w`` branch is used when the trace is positive;
    - otherwise the branch for the largest diagonal entry is used.

    The square-root arguments are floored at ``eps**2`` rather than at ``0``.
    For any finite input, the branch that is actually selected has a
    square-root argument of at least ``1``, so the returned value is not
    affected. The floor keeps the derivatives of the *unselected* branches
    finite. Without it, ``jax.grad`` yields NaN for common inputs such as the
    identity or a rotation about a principal axis: the derivative of
    ``sqrt(0)`` is infinite, and ``jnp.where`` multiplies it by a zero
    cotangent.

    Args:
        rotation_matrix: The rotation matrix, shape ``(*, 3, 3)``.
        eps: A small positive value that keeps square roots, divisions and the
            final normalisation away from zero.

    Returns:
        A quaternion with shape ``(*, 4)``, normalised to unit length. The sign
        is not canonicalised, so ``w`` may be negative.
    """
    chex.assert_shape(rotation_matrix, (..., 3, 3))

    m00 = rotation_matrix[..., 0, 0]
    m01 = rotation_matrix[..., 0, 1]
    m02 = rotation_matrix[..., 0, 2]
    m10 = rotation_matrix[..., 1, 0]
    m11 = rotation_matrix[..., 1, 1]
    m12 = rotation_matrix[..., 1, 2]
    m20 = rotation_matrix[..., 2, 0]
    m21 = rotation_matrix[..., 2, 1]
    m22 = rotation_matrix[..., 2, 2]

    trace = m00 + m11 + m22

    def _scale(arg: Array) -> Array:
        # 2 * sqrt(arg), with arg floored at eps**2 so the result is >= 2 * eps
        # (safe to divide by) and the gradient stays finite even when arg <= 0.
        return 2.0 * jnp.sqrt(jnp.maximum(arg, eps * eps))

    # Case 0: trace is positive, S = 4 * qw.
    s0 = _scale(1.0 + trace)
    q0 = jnp.stack([0.25 * s0, (m21 - m12) / s0, (m02 - m20) / s0, (m10 - m01) / s0], axis=-1)

    # Case 1: m00 is the largest diagonal term, S = 4 * qx.
    s1 = _scale(1.0 + m00 - m11 - m22)
    q1 = jnp.stack([(m21 - m12) / s1, 0.25 * s1, (m01 + m10) / s1, (m02 + m20) / s1], axis=-1)

    # Case 2: m11 is the largest diagonal term, S = 4 * qy.
    s2 = _scale(1.0 + m11 - m00 - m22)
    q2 = jnp.stack([(m02 - m20) / s2, (m01 + m10) / s2, 0.25 * s2, (m12 + m21) / s2], axis=-1)

    # Case 3: m22 is the largest diagonal term, S = 4 * qz.
    s3 = _scale(1.0 + m22 - m00 - m11)
    q3 = jnp.stack([(m10 - m01) / s3, (m02 + m20) / s3, (m12 + m21) / s3, 0.25 * s3], axis=-1)

    # Branch selection; the conditions gain a trailing axis so they broadcast
    # over the four quaternion components.
    use_case0 = (trace > 0.0)[..., None]
    use_case1 = ((m00 > m11) & (m00 > m22))[..., None]
    use_case2 = (m11 > m22)[..., None]
    quat = jnp.where(use_case0, q0, jnp.where(use_case1, q1, jnp.where(use_case2, q2, q3)))

    # Dividing by max(norm, eps) instead of (norm + eps) returns an exactly
    # unit-length quaternion while still guarding against a zero norm.
    norm = jnp.linalg.norm(quat, axis=-1, keepdims=True)
    return quat / jnp.maximum(norm, eps)
xax/task/base.py CHANGED
@@ -82,6 +82,9 @@ class BaseTask(Generic[Config]):
82
82
  def on_after_checkpoint_save(self, ckpt_path: Path, state: State | None) -> State | None:
83
83
  return state
84
84
 
85
def add_logger_handlers(self, logger: logging.Logger) -> None:
    """Hook for registering extra logging handlers on ``logger``.

    The base task registers nothing. Subclasses and mixins may override this
    hook. Overrides should call ``super().add_logger_handlers(logger)`` so that
    every hook in the MRO gets a chance to run.

    Args:
        logger: The logger that handlers should be attached to.
    """
    del logger  # Unused by the base implementation.
87
+
85
88
  @functools.cached_property
86
89
  def task_class_name(self) -> str:
87
90
  return self.__class__.__name__
@@ -15,8 +15,9 @@ def run_single_process_training(
15
15
  *cfgs: RawConfigType,
16
16
  use_cli: bool | list[str] = True,
17
17
  ) -> None:
18
- configure_logging()
18
+ logger = configure_logging()
19
19
  task_obj = task.get_task(*cfgs, use_cli=use_cli)
20
+ task_obj.add_logger_handlers(logger)
20
21
  task_obj.run()
21
22
 
22
23
 
@@ -14,7 +14,7 @@ from xax.core.state import State
14
14
  from xax.nn.parallel import is_master
15
15
  from xax.task.base import BaseConfig, BaseTask
16
16
  from xax.utils.experiments import stage_environment
17
- from xax.utils.logging import LOG_STATUS
17
+ from xax.utils.logging import LOG_STATUS, RankFilter
18
18
  from xax.utils.text import show_info
19
19
 
20
20
  logger = logging.getLogger(__name__)
@@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
24
24
@dataclass
class ArtifactsConfig(BaseConfig):
    """Configuration for where a task writes its artifacts and logs."""

    # Fixed experiment directory; defaults to None. How an unset value is
    # resolved is not shown here (presumably by ArtifactsMixin.exp_dir).
    exp_dir: str | None = field(None, help="The fixed experiment directory")
    # When True, ArtifactsMixin.add_logger_handlers attaches a file handler
    # that writes log records to ``<exp_dir>/logs.txt`` on the master process.
    log_to_file: bool = field(True, help="If set, add a file handler to the logger to write all logs to the exp dir")
27
28
 
28
29
 
29
30
  Config = TypeVar("Config", bound=ArtifactsConfig)
@@ -39,6 +40,14 @@ class ArtifactsMixin(BaseTask[Config]):
39
40
  self._exp_dir = None
40
41
  self._stage_dir = None
41
42
 
43
def add_logger_handlers(self, logger: logging.Logger) -> None:
    """Attaches a file handler that mirrors log records to ``<exp_dir>/logs.txt``.

    Parent hooks run first (via ``super()``). The file handler is added only
    when both conditions hold:

    - the current process is the master;
    - ``config.log_to_file`` is set.

    The handler also gets a ``RankFilter(rank=0)``.

    Args:
        logger: The logger to attach the handler to. The single-process
            launcher passes the logger returned by ``configure_logging``.
    """
    super().add_logger_handlers(logger)
    if not (is_master() and self.config.log_to_file):
        return

    # Explicit UTF-8 so the log file does not depend on the platform's locale
    # encoding. With a non-UTF-8 locale, records containing non-ASCII text
    # would otherwise fail to be written.
    file_handler = logging.FileHandler(self.exp_dir / "logs.txt", encoding="utf-8")
    file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
    file_handler.addFilter(RankFilter(rank=0))
    logger.addHandler(file_handler)
50
+
42
51
  @functools.cached_property
43
52
  def run_dir(self) -> Path:
44
53
  run_dir = get_run_dir()
@@ -10,6 +10,7 @@ import jax
10
10
 
11
11
  from xax.task.base import BaseConfig, BaseTask, RawConfigType
12
12
  from xax.task.launchers.base import BaseLauncher
13
+ from xax.task.launchers.cli import CliLauncher
13
14
 
14
15
 
15
16
  @jax.tree_util.register_dataclass
@@ -45,8 +46,6 @@ class RunnableMixin(BaseTask[Config], ABC):
45
46
  use_cli: bool | list[str] = True,
46
47
  ) -> None:
47
48
  if launcher is None:
48
- from xax.task.launchers.cli import CliLauncher
49
-
50
49
  launcher = CliLauncher()
51
50
  launcher.launch(cls, *cfgs, use_cli=use_cli)
52
51
 
xax/utils/logging.py CHANGED
@@ -146,7 +146,7 @@ def configure_logging(
146
146
  rank: int | None = None,
147
147
  world_size: int | None = None,
148
148
  debug: bool | None = None,
149
- ) -> None:
149
+ ) -> logging.Logger:
150
150
  """Instantiates logging.
151
151
 
152
152
  This captures logs and reroutes them to the Toasts module, which is
@@ -186,6 +186,8 @@ def configure_logging(
186
186
  logging.getLogger("PIL").setLevel(logging.WARNING)
187
187
  logging.getLogger("torch").setLevel(logging.WARNING)
188
188
 
189
+ return root_logger
190
+
189
191
 
190
192
  def get_unused_port(default: int | None = None) -> int:
191
193
  """Returns an unused port number on the local machine.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=Ku51ccrDjxWuSCgmHHVnTXhloyA_d9A1VlEEbO0Ycjg,15713
1
+ xax/__init__.py,sha256=ffVd9_qSVuEAIPn6eK_6N8qEfDALRJigArHZQGy6y1o,15819
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -6,25 +6,25 @@ xax/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  xax/cli/edit_config.py,sha256=LQUIlOS6hvPZyVEaMme3FP-62M0BKQPYavCwVDWuBLw,2600
7
7
  xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
9
- xax/core/state.py,sha256=F9Tj3FfCw8zFKaDEoEGiThZE2ntYEtzNjnBX3pQ1g60,3826
9
+ xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
10
10
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  xax/nn/attention.py,sha256=0essK90OO3x9FxnUqU0DhufwXKRMN41zMtRCki5iKzQ,24742
12
12
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
13
13
  xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
14
- xax/nn/geom.py,sha256=7_cUeoJIfCS_YbQ6O7zElnpYsAn4XrVm4m7PaT9UcYw,8005
14
+ xax/nn/geom.py,sha256=6rBQrZRX1miG08VG-s8phPjA6MEFxUAfQVPt5F0RQQI,10645
15
15
  xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
16
16
  xax/nn/metrics.py,sha256=zuvPXlRQczBTLHD4ilNGmZaiq6Yie3rxCMq6JkI_kos,3154
17
17
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
18
18
  xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
19
19
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
- xax/task/base.py,sha256=TYANmjNcce4_V5ZSYLnE91PXRn7Nn0nT7hN8plW_Au0,8117
20
+ xax/task/base.py,sha256=i6FRJ75aqlekWkzJNRWDUEX7P514pUjLVuxjhX1GBgw,8198
21
21
  xax/task/logger.py,sha256=Bmhl4mv08Aq49ZyX6BdjPIsPJK28e8s3mVFatM4IY2Q,41060
22
22
  xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
23
23
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
24
24
  xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
25
  xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
26
26
  xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
27
- xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
27
+ xax/task/launchers/single_process.py,sha256=wdEUT-B-FE9aemmt1tB_rKKRNy60aiDhslsy2i-ojWo,896
28
28
  xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
29
  xax/task/loggers/callback.py,sha256=zQuV1xCvz47Q3UQqP1D5mBhbVzptvmPR_7hX25vqSk0,1667
30
30
  xax/task/loggers/json.py,sha256=6A5wL7kspsXnpPhI_vu0scgd2Z2-WLhw4gbBFm7eZMM,4377
@@ -32,7 +32,7 @@ xax/task/loggers/state.py,sha256=0Jy0NYnY4c0qt0LvNlaTaCKOSqk5SCKln5VdyuQGnIc,140
32
32
  xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,6619
33
33
  xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
34
34
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
35
- xax/task/mixins/artifacts.py,sha256=IBPAQMCGd7PQiZHfSjLakPW5j7cNuL6AsW6QkVSc02E,3277
35
+ xax/task/mixins/artifacts.py,sha256=R-y3p7__zJHlHDqwDVAZysg2ZmebCJbqAx_xGT2Xpd0,3857
36
36
  xax/task/mixins/checkpointing.py,sha256=v50IZ7j58DWmEu-_6Zh_02R5KUVGhrMkg5n-MYM_J4c,11484
37
37
  xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
38
38
  xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
@@ -40,7 +40,7 @@ xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-go
40
40
  xax/task/mixins/gpu_stats.py,sha256=USOyhXldxbsrl6eCtoFKTWUm_lfeG0cUCkQNUpXRdtA,8880
41
41
  xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
42
42
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
43
- xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
+ xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
44
44
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
45
45
  xax/task/mixins/train.py,sha256=TZatz5QwTfrNhQTiO2IqrmQY9P4Lay6FAD2VsQpWa54,33245
46
46
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -48,7 +48,7 @@ xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
48
48
  xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
49
49
  xax/utils/jax.py,sha256=6cP95-rcjkRt1fefkZWJQhJhH0uUYWJB3w4NP1-aDp0,10136
50
50
  xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
51
- xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
51
+ xax/utils/logging.py,sha256=Kkyma_LJXqrN2HTQ214gRP_9ih3_bKk115MWC60lQWM,6656
52
52
  xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
53
53
  xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
54
54
  xax/utils/pytree.py,sha256=rVY2kKa637xfX3Oue6OP9ScwmDyxJ_CeHkUpZZtmN04,9231
@@ -59,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
59
59
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
60
60
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
61
61
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
62
- xax-0.3.1.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
63
- xax-0.3.1.dist-info/METADATA,sha256=xAwDUWNoga60_KZ_nPEXrdAuvlak-zWAXk3lE-XvLeI,1246
64
- xax-0.3.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
65
- xax-0.3.1.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
66
- xax-0.3.1.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
67
- xax-0.3.1.dist-info/RECORD,,
62
+ xax-0.3.3.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
63
+ xax-0.3.3.dist-info/METADATA,sha256=mjIzoFZDSR3V1-2LHbvup6wDVa4vLiqbqiNWLsKCXY8,1246
64
+ xax-0.3.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
65
+ xax-0.3.3.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
66
+ xax-0.3.3.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
67
+ xax-0.3.3.dist-info/RECORD,,
File without changes