xax 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.0"
15
+ __version__ = "0.2.1"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -44,6 +44,7 @@ __all__ = [
44
44
  "euler_to_quat",
45
45
  "get_projected_gravity_vector_from_quat",
46
46
  "quat_to_euler",
47
+ "quat_to_rotmat",
47
48
  "rotate_vector_by_quat",
48
49
  "cross_entropy",
49
50
  "cast_norm_type",
@@ -206,6 +207,7 @@ NAME_MAP: dict[str, str] = {
206
207
  "euler_to_quat": "nn.geom",
207
208
  "get_projected_gravity_vector_from_quat": "nn.geom",
208
209
  "quat_to_euler": "nn.geom",
210
+ "quat_to_rotmat": "nn.geom",
209
211
  "rotate_vector_by_quat": "nn.geom",
210
212
  "cross_entropy": "nn.losses",
211
213
  "cast_norm_type": "nn.norm",
@@ -369,6 +371,7 @@ if IMPORT_ALL or TYPE_CHECKING:
369
371
  euler_to_quat,
370
372
  get_projected_gravity_vector_from_quat,
371
373
  quat_to_euler,
374
+ quat_to_rotmat,
372
375
  rotate_vector_by_quat,
373
376
  )
374
377
  from xax.nn.losses import cross_entropy
xax/nn/geom.py CHANGED
@@ -177,3 +177,37 @@ def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
177
177
  y_diff = y_end - y_start
178
178
  bezier = x**3 + 3 * (x**2 * (1 - x))
179
179
  return y_start + y_diff * bezier
180
+
181
+
182
def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
    """Converts a quaternion to a rotation matrix.

    The quaternion is read in scalar-first ``(w, x, y, z)`` order. It is
    normalized before conversion, so non-unit quaternions are accepted.

    Args:
        quat: The quaternion to convert, shape (*, 4).
        eps: Lower bound on the quaternion norm, used to avoid division by
            zero. An all-zero quaternion maps to the identity matrix.

    Returns:
        The rotation matrix, shape (*, 3, 3), with rows along axis -2.
    """
    # Clamp the norm rather than adding ``eps`` to it: adding ``eps`` shrinks
    # every quaternion (even unit ones), making the resulting matrix slightly
    # non-orthonormal.
    norm = jnp.linalg.norm(quat, axis=-1, keepdims=True)
    quat = quat / jnp.maximum(norm, eps)
    w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3]

    rows = [
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ]
    # Each row stacks to (*, 3); stacking rows on axis -2 gives (*, 3, 3).
    return jnp.stack([jnp.stack(row, axis=-1) for row in rows], axis=-2)
@@ -63,10 +63,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
63
63
 
64
64
  def get_init_ckpt_path(self) -> Path | None:
65
65
  if self._exp_dir is not None:
66
- ckpt_path = self.get_ckpt_path()
67
- if not ckpt_path.exists():
68
- logger.warning("No checkpoint found in experiment directory: %s", ckpt_path)
69
- else:
66
+ if (ckpt_path := self.get_ckpt_path()).exists():
70
67
  return ckpt_path
71
68
  if self.config.load_from_ckpt_path is not None:
72
69
  ckpt_path = Path(self.config.load_from_ckpt_path)
@@ -86,7 +83,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
86
83
  return False
87
84
 
88
85
  @overload
89
- def load_checkpoint(
86
+ def load_ckpt_with_template(
90
87
  self,
91
88
  path: Path,
92
89
  *,
@@ -97,7 +94,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
97
94
  ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
98
95
 
99
96
  @overload
100
- def load_checkpoint(
97
+ def load_ckpt_with_template(
101
98
  self,
102
99
  path: Path,
103
100
  *,
@@ -106,7 +103,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
106
103
  ) -> tuple[PyTree, State, Config]: ...
107
104
 
108
105
  @overload
109
- def load_checkpoint(
106
+ def load_ckpt_with_template(
110
107
  self,
111
108
  path: Path,
112
109
  *,
@@ -115,7 +112,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
115
112
  ) -> PyTree: ...
116
113
 
117
114
  @overload
118
- def load_checkpoint(
115
+ def load_ckpt_with_template(
119
116
  self,
120
117
  path: Path,
121
118
  *,
@@ -124,7 +121,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
124
121
  ) -> optax.GradientTransformation: ...
125
122
 
126
123
  @overload
127
- def load_checkpoint(
124
+ def load_ckpt_with_template(
128
125
  self,
129
126
  path: Path,
130
127
  *,
@@ -133,7 +130,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
133
130
  ) -> optax.OptState: ...
134
131
 
135
132
  @overload
136
- def load_checkpoint(
133
+ def load_ckpt_with_template(
137
134
  self,
138
135
  path: Path,
139
136
  *,
@@ -141,14 +138,14 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
141
138
  ) -> State: ...
142
139
 
143
140
  @overload
144
- def load_checkpoint(
141
+ def load_ckpt_with_template(
145
142
  self,
146
143
  path: Path,
147
144
  *,
148
145
  part: Literal["config"],
149
146
  ) -> Config: ...
150
147
 
151
- def load_checkpoint(
148
+ def load_ckpt_with_template(
152
149
  self,
153
150
  path: Path,
154
151
  *,
xax/task/mixins/train.py CHANGED
@@ -12,6 +12,7 @@ import time
12
12
  import traceback
13
13
  from abc import ABC, abstractmethod
14
14
  from dataclasses import asdict, dataclass, is_dataclass
15
+ from pathlib import Path
15
16
  from threading import Thread
16
17
  from typing import (
17
18
  Any,
@@ -39,7 +40,7 @@ from xax.core.state import Phase, State
39
40
  from xax.nn.functions import set_random_seed
40
41
  from xax.nn.parallel import is_master
41
42
  from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
42
- from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
43
+ from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart
43
44
  from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
44
45
  from xax.task.mixins.logger import LoggerConfig, LoggerMixin
45
46
  from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
@@ -54,7 +55,7 @@ from xax.utils.experiments import (
54
55
  get_training_code,
55
56
  )
56
57
  from xax.utils.jax import jit as xax_jit
57
- from xax.utils.logging import LOG_STATUS
58
+ from xax.utils.logging import LOG_PING, LOG_STATUS
58
59
  from xax.utils.text import highlight_exception_message, show_info
59
60
  from xax.utils.types.frozen_dict import FrozenDict
60
61
 
@@ -340,12 +341,7 @@ class TrainMixin(
340
341
 
341
342
  if init_ckpt_path is not None:
342
343
  logger.info("Loading checkpoint from %s", init_ckpt_path)
343
- model_spec = eqx.filter_eval_shape(self.get_model, key)
344
- model, state, config = self.load_checkpoint(
345
- init_ckpt_path,
346
- part="model_state_config",
347
- model_template=model_spec,
348
- )
344
+ model, state, config = self.load_ckpt(init_ckpt_path, part="model_state_config")
349
345
  config_diff = get_diff_string(diff_configs(asdict(config), asdict(self.config)))
350
346
  if config_diff:
351
347
  logger.warning("Loaded config differs from current config:\n%s", config_diff)
@@ -353,17 +349,11 @@ class TrainMixin(
353
349
  if not load_optimizer:
354
350
  return model, state
355
351
 
356
- # Loads the optimizer.
357
- optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
358
- optimizer = self.load_checkpoint(init_ckpt_path, part="opt", optimizer_template=optimizer_spec)
359
-
360
- # Loads the optimizer state.
361
- opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
362
- opt_state = self.load_checkpoint(init_ckpt_path, part="opt_state", opt_state_template=opt_state_spec)
363
-
352
+ optimizer = self.load_ckpt(init_ckpt_path, part="opt")
353
+ opt_state = self.load_ckpt(init_ckpt_path, part="opt_state", model=model, optimizer=optimizer)
364
354
  return model, optimizer, opt_state, state
365
355
 
366
- logger.info("No checkpoint found. Initializing a new model.")
356
+ logger.info("Starting a new training run")
367
357
  model = self.get_model(key)
368
358
  state = State.init_state()
369
359
 
@@ -375,6 +365,131 @@ class TrainMixin(
375
365
 
376
366
  return model, optimizer, opt_state, state
377
367
 
368
# Typed overloads for ``load_ckpt``: the return type depends on ``part``.
# ``path`` accepts ``str | Path`` to match the implementation, which coerces it.
@overload
def load_ckpt(
    self,
    path: str | Path,
    *,
    part: Literal["all"] = "all",
) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...

@overload
def load_ckpt(
    self,
    path: str | Path,
    *,
    part: Literal["model_state_config"],
) -> tuple[PyTree, State, Config]: ...

@overload
def load_ckpt(
    self,
    path: str | Path,
    *,
    part: Literal["model"],
) -> PyTree: ...

@overload
def load_ckpt(
    self,
    path: str | Path,
    *,
    part: Literal["opt"],
) -> optax.GradientTransformation: ...

@overload
def load_ckpt(
    self,
    path: str | Path,
    *,
    part: Literal["opt_state"],
    model: PyTree | None = None,
    optimizer: optax.GradientTransformation | None = None,
) -> optax.OptState: ...

@overload
def load_ckpt(
    self,
    path: str | Path,
    *,
    part: Literal["state"],
) -> State: ...

@overload
def load_ckpt(
    self,
    path: str | Path,
    *,
    part: Literal["config"],
) -> Config: ...
425
+
426
+ def load_ckpt(
427
+ self,
428
+ path: str | Path,
429
+ *,
430
+ part: CheckpointPart = "all",
431
+ model: PyTree | None = None,
432
+ optimizer: optax.GradientTransformation | None = None,
433
+ ) -> (
434
+ tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
435
+ | tuple[PyTree, State, Config]
436
+ | PyTree
437
+ | optax.GradientTransformation
438
+ | optax.OptState
439
+ | State
440
+ | Config
441
+ ):
442
+ path = Path(path)
443
+
444
+ # This key isn't used for anything, it's just a required argument.
445
+ key = jax.random.PRNGKey(0)
446
+
447
+ match part:
448
+ case "model_state_config":
449
+ model_spec = eqx.filter_eval_shape(self.get_model, key)
450
+ return self.load_ckpt_with_template(path, part="model_state_config", model_template=model_spec)
451
+
452
+ case "model":
453
+ model_spec = eqx.filter_eval_shape(self.get_model, key)
454
+ return self.load_ckpt_with_template(path, part="model", model_template=model_spec)
455
+
456
+ case "config":
457
+ return self.load_ckpt_with_template(path, part="config")
458
+
459
+ case "opt":
460
+ optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
461
+ return self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
462
+
463
+ case "opt_state":
464
+ if model is None:
465
+ model_spec = eqx.filter_eval_shape(self.get_model, key)
466
+ model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
467
+ if optimizer is None:
468
+ optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
469
+ optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
470
+ opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
471
+ return self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
472
+
473
+ case "state":
474
+ return self.load_ckpt_with_template(path, part="state")
475
+
476
+ case "config":
477
+ return self.load_ckpt_with_template(path, part="config")
478
+
479
+ case "all":
480
+ model_spec = eqx.filter_eval_shape(self.get_model, key)
481
+ model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
482
+ optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
483
+ optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
484
+ opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
485
+ opt_state = self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
486
+ state = self.load_ckpt_with_template(path, part="state")
487
+ config = self.load_ckpt_with_template(path, part="config")
488
+ return model, optimizer, opt_state, state, config
489
+
490
+ case _:
491
+ raise ValueError(f"Unknown checkpoint part: {part}")
492
+
378
493
  def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
379
494
  """Gets the output from the model.
380
495
 
@@ -529,8 +644,7 @@ class TrainMixin(
529
644
  self._last_printed_remaining_time = state.elapsed_time_s
530
645
  remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
531
646
  termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
532
- # logger.info("Estimated finish time: %s", termination_time)
533
- jax.debug.print("Estimated finish time: {}", termination_time)
647
+ logger.log(LOG_PING, "Estimated finish time: %s", termination_time)
534
648
 
535
649
  def get_remaining_percent(self, state: State) -> float | None:
536
650
  if self.config.max_steps is None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=CO9UZlYsYsDL2B6z-Id0Fv0ZSD5uwUZ3eZ6zwwqtJhU,14103
1
+ xax/__init__.py,sha256=kd-88OQGnuHb91PXwroAfLb0bMfbe37fXqpECRrjhoU,14182
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -10,7 +10,7 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
10
  xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
11
11
  xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
12
12
  xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
13
- xax/nn/geom.py,sha256=PN0Ndn575aVtsSfxi67RghHB7luRkqtpS7bPbT1LpLE,5201
13
+ xax/nn/geom.py,sha256=rImNlkHWeoNcY7f84nknizJ6uzsrMhbAtKeb2xAWxNY,6215
14
14
  xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
15
15
  xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
16
16
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
@@ -32,7 +32,7 @@ xax/task/loggers/stdout.py,sha256=oeIgPkj4RyJgBuWaJK9ncLa65iBNJCWXhSF8fx3_54c,65
32
32
  xax/task/loggers/tensorboard.py,sha256=KOL9l60tLctX-VAdNwe49H48SAJeGxph3sflJpojA-4,8337
33
33
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
34
34
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
35
- xax/task/mixins/checkpointing.py,sha256=JHBOdcgmJvhyXldPF5pHRmyPUN9SHcxxngsC1ap4b1E,11468
35
+ xax/task/mixins/checkpointing.py,sha256=2nJgqFcV-D8W-4j8TR3PvVh1g5hQUOo-_quKO-XlE4U,11398
36
36
  xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
37
37
  xax/task/mixins/cpu_stats.py,sha256=vAjEc3HpPnl56m7vshYX0dXAHJrB98DzVdsYSRqQllc,9371
38
38
  xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
@@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
41
41
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
42
42
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
43
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
44
- xax/task/mixins/train.py,sha256=t8Qyw40ahuJW0SPVgFLljqYbbSc1M_WLop87iwYE41Q,27064
44
+ xax/task/mixins/train.py,sha256=v9oi9tNsNBYo-Ne_98nCG9qHX6sxvymHjsRDnL6GL-U,30871
45
45
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
46
  xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
47
47
  xax/utils/experiments.py,sha256=Hzl46_9IH5_9cKzxit-FyVUWBH-_lBs00ZciuIdnWO8,29811
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.2.0.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.2.0.dist-info/METADATA,sha256=FyMDy4yB_KQF_IdCMMe_10VWpIEE5g6qEIZuXx-pLgU,1882
63
- xax-0.2.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
- xax-0.2.0.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.2.0.dist-info/RECORD,,
61
+ xax-0.2.1.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.2.1.dist-info/METADATA,sha256=2pOZLKMIcLoQTM-tRqRvVkF57PZyMoALM87UI5B4dtk,1882
63
+ xax-0.2.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
+ xax-0.2.1.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.2.1.dist-info/RECORD,,
File without changes