xax 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.13"
15
+ __version__ = "0.1.15"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -40,6 +40,7 @@ __all__ = [
40
40
  "load_eqx_mlp",
41
41
  "make_eqx_mlp",
42
42
  "save_eqx",
43
+ "cubic_bezier_interpolation",
43
44
  "euler_to_quat",
44
45
  "get_projected_gravity_vector_from_quat",
45
46
  "quat_to_euler",
@@ -77,6 +78,7 @@ __all__ = [
77
78
  "collate_non_null",
78
79
  "breakpoint_if_nan",
79
80
  "get_named_leaves",
81
+ "log_if_nan",
80
82
  "BaseFileDownloader",
81
83
  "ContextTimer",
82
84
  "CumulativeTimer",
@@ -200,6 +202,7 @@ NAME_MAP: dict[str, str] = {
200
202
  "load_eqx_mlp": "nn.equinox",
201
203
  "make_eqx_mlp": "nn.equinox",
202
204
  "save_eqx": "nn.equinox",
205
+ "cubic_bezier_interpolation": "nn.geom",
203
206
  "euler_to_quat": "nn.geom",
204
207
  "get_projected_gravity_vector_from_quat": "nn.geom",
205
208
  "quat_to_euler": "nn.geom",
@@ -237,6 +240,7 @@ NAME_MAP: dict[str, str] = {
237
240
  "collate_non_null": "utils.data.collate",
238
241
  "breakpoint_if_nan": "utils.debugging",
239
242
  "get_named_leaves": "utils.debugging",
243
+ "log_if_nan": "utils.debugging",
240
244
  "BaseFileDownloader": "utils.experiments",
241
245
  "ContextTimer": "utils.experiments",
242
246
  "CumulativeTimer": "utils.experiments",
@@ -361,6 +365,7 @@ if IMPORT_ALL or TYPE_CHECKING:
361
365
  save_eqx,
362
366
  )
363
367
  from xax.nn.geom import (
368
+ cubic_bezier_interpolation,
364
369
  euler_to_quat,
365
370
  get_projected_gravity_vector_from_quat,
366
371
  quat_to_euler,
@@ -388,7 +393,7 @@ if IMPORT_ALL or TYPE_CHECKING:
388
393
  from xax.task.script import Script, ScriptConfig
389
394
  from xax.task.task import Config, Task
390
395
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
391
- from xax.utils.debugging import breakpoint_if_nan, get_named_leaves
396
+ from xax.utils.debugging import breakpoint_if_nan, get_named_leaves, log_if_nan
392
397
  from xax.utils.experiments import (
393
398
  BaseFileDownloader,
394
399
  ContextTimer,
xax/nn/geom.py CHANGED
@@ -1,10 +1,10 @@
1
1
  """Defines geometry functions."""
2
2
 
3
- import jax
4
3
  from jax import numpy as jnp
4
+ from jaxtyping import Array
5
5
 
6
6
 
7
- def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
7
+ def quat_to_euler(quat_4: Array, eps: float = 1e-6) -> Array:
8
8
  """Normalizes and converts a quaternion (w, x, y, z) to roll, pitch, yaw.
9
9
 
10
10
  Args:
@@ -40,7 +40,7 @@ def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
40
40
  return jnp.concatenate([roll, pitch, yaw], axis=-1)
41
41
 
42
42
 
43
- def euler_to_quat(euler_3: jax.Array) -> jax.Array:
43
+ def euler_to_quat(euler_3: Array) -> Array:
44
44
  """Converts roll, pitch, yaw angles to a quaternion (w, x, y, z).
45
45
 
46
46
  Args:
@@ -75,7 +75,7 @@ def euler_to_quat(euler_3: jax.Array) -> jax.Array:
75
75
  return quat
76
76
 
77
77
 
78
- def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -> jax.Array:
78
+ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Array:
79
79
  """Calculates the gravity vector projected onto the local frame given a quaternion orientation.
80
80
 
81
81
  Args:
@@ -101,7 +101,7 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -
101
101
  return jnp.concatenate([gx, gy, -gz], axis=-1)
102
102
 
103
103
 
104
- def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) -> jax.Array:
104
+ def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Array:
105
105
  """Rotates a vector by a quaternion.
106
106
 
107
107
  Args:
@@ -156,3 +156,24 @@ def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6)
156
156
  )
157
157
 
158
158
  return jnp.concatenate([xx, yy, zz], axis=-1)
159
+
160
+
161
+ def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
162
+ """Cubic bezier interpolation.
163
+
164
+ This is a cubic bezier curve that starts at y_start and ends at y_end,
165
+ and is controlled by the parameter x. The curve is defined by the following formula:
166
+
167
+ y(x) = y_start + (y_end - y_start) * (x**3 + 3 * (x**2 * (1 - x)))
168
+
169
+ Args:
170
+ y_start: The start value, shape (*).
171
+ y_end: The end value, shape (*).
172
+ x: The interpolation parameter, shape (*).
173
+
174
+ Returns:
175
+ The interpolated value, shape (*).
176
+ """
177
+ y_diff = y_end - y_start
178
+ bezier = x**3 + 3 * (x**2 * (1 - x))
179
+ return y_start + y_diff * bezier
xax/task/base.py CHANGED
@@ -82,9 +82,6 @@ class BaseTask(Generic[Config]):
82
82
  def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
83
83
  return state
84
84
 
85
- def on_before_checkpoint_load(self, ckpt_path: Path) -> None:
86
- pass
87
-
88
85
  @functools.cached_property
89
86
  def task_class_name(self) -> str:
90
87
  return self.__class__.__name__
@@ -98,19 +98,39 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
98
98
  ) -> tuple[PyTree, State, DictConfig]: ...
99
99
 
100
100
  @overload
101
- def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
101
+ def load_checkpoint(
102
+ self,
103
+ path: Path,
104
+ part: Literal["model"],
105
+ ) -> PyTree: ...
102
106
 
103
107
  @overload
104
- def load_checkpoint(self, path: Path, part: Literal["opt"]) -> optax.GradientTransformation: ...
108
+ def load_checkpoint(
109
+ self,
110
+ path: Path,
111
+ part: Literal["opt"],
112
+ ) -> optax.GradientTransformation: ...
105
113
 
106
114
  @overload
107
- def load_checkpoint(self, path: Path, part: Literal["opt_state"]) -> optax.OptState: ...
115
+ def load_checkpoint(
116
+ self,
117
+ path: Path,
118
+ part: Literal["opt_state"],
119
+ ) -> optax.OptState: ...
108
120
 
109
121
  @overload
110
- def load_checkpoint(self, path: Path, part: Literal["state"]) -> State: ...
122
+ def load_checkpoint(
123
+ self,
124
+ path: Path,
125
+ part: Literal["state"],
126
+ ) -> State: ...
111
127
 
112
128
  @overload
113
- def load_checkpoint(self, path: Path, part: Literal["config"]) -> DictConfig: ...
129
+ def load_checkpoint(
130
+ self,
131
+ path: Path,
132
+ part: Literal["config"],
133
+ ) -> DictConfig: ...
114
134
 
115
135
  def load_checkpoint(
116
136
  self,
@@ -125,9 +145,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
125
145
  | State
126
146
  | DictConfig
127
147
  ):
128
- # Calls the base callback.
129
- self.on_before_checkpoint_load(path)
130
-
131
148
  with tarfile.open(path, "r:gz") as tar:
132
149
 
133
150
  def get_model() -> PyTree:
@@ -32,6 +32,10 @@ def get_cache_dir() -> str | None:
32
32
  @dataclass
33
33
  class CompileOptions:
34
34
  # JAX compilation options
35
+ debug_nans: bool = field(
36
+ value=False,
37
+ help="If True, breaks on NaNs",
38
+ )
35
39
  disable_jit: bool = field(
36
40
  value=False,
37
41
  help="If True, disables JIT compilation",
@@ -89,6 +93,10 @@ class CompileMixin(BaseTask[Config], Generic[Config]):
89
93
  cc = self.config.compile
90
94
 
91
95
  # Set basic compilation flags
96
+ if cc.debug_nans:
97
+ logger.info("Enabling NaNs debugging")
98
+ jax.config.update("jax_debug_nans", True)
99
+
92
100
  if cc.disable_jit:
93
101
  logger.info("Disabling JIT compilation")
94
102
  jax.config.update("jax_disable_jit", True)
xax/task/mixins/train.py CHANGED
@@ -50,8 +50,7 @@ from xax.utils.experiments import (
50
50
  TrainingFinishedError,
51
51
  diff_configs,
52
52
  get_diff_string,
53
- get_git_state,
54
- get_packages_with_versions,
53
+ get_state_file_string,
55
54
  get_training_code,
56
55
  )
57
56
  from xax.utils.jax import jit as xax_jit
@@ -534,9 +533,8 @@ class TrainMixin(
534
533
  logger.log(LOG_STATUS, self.task_path)
535
534
  logger.log(LOG_STATUS, self.task_name)
536
535
  logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
537
- self.logger.log_file("git_state.txt", get_git_state(self))
538
- self.logger.log_file("packages.txt", get_packages_with_versions())
539
- self.logger.log_file("training_code.txt", get_training_code(self))
536
+ self.logger.log_file("state.txt", get_state_file_string(self))
537
+ self.logger.log_file("training_code.py", get_training_code(self))
540
538
  self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
541
539
 
542
540
  def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
xax/utils/debugging.py CHANGED
@@ -53,3 +53,7 @@ def get_named_leaves(
53
53
 
54
54
  def breakpoint_if_nan(x: Array) -> None:
55
55
  jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.breakpoint(), lambda: None)
56
+
57
+
58
+ def log_if_nan(x: Array, loc: str) -> None:
59
+ jax.lax.cond(jnp.any(jnp.isnan(x)), lambda: jax.debug.print("=== NaNs: {loc} ===", loc=loc), lambda: None)
xax/utils/experiments.py CHANGED
@@ -479,6 +479,20 @@ def get_packages_with_versions() -> str:
479
479
  return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
480
480
 
481
481
 
482
+ def get_command_line_string() -> str:
483
+ return " ".join(sys.argv)
484
+
485
+
486
+ def get_state_file_string(obj: object) -> str:
487
+ return "\n\n".join(
488
+ [
489
+ f"=== Command Line ===\n\n{get_command_line_string()}",
490
+ f"=== Git State ===\n\n{get_git_state(obj)}",
491
+ f"=== Packages ===\n\n{get_packages_with_versions()}",
492
+ ]
493
+ )
494
+
495
+
482
496
  def get_training_code(obj: object) -> str:
483
497
  """Gets the text from the file containing the provided object.
484
498
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.13
3
+ Version: 0.1.15
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=zBQiNKXLrE-9UPy1q-XBFwPNAKTAxr6wpAwYwaTVggs,13922
1
+ xax/__init__.py,sha256=bV2mTcuiVaVNvwgbDgg7dKDkMeuyA0mqF0muU5KZHeg,14104
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -10,13 +10,13 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
10
  xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
11
11
  xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
12
12
  xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
13
- xax/nn/geom.py,sha256=Bj9Z4Y-uoNQuaA_eB_MyG7yImZLuOq8KCLUj1l3daoc,4545
13
+ xax/nn/geom.py,sha256=PN0Ndn575aVtsSfxi67RghHB7luRkqtpS7bPbT1LpLE,5201
14
14
  xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
15
15
  xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
16
16
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
17
17
  xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
18
18
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
- xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
19
+ xax/task/base.py,sha256=DqgGIlo5kEWpYix3DdPCEkCgVLUOocjyFr8okaSUq-k,7680
20
20
  xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
21
21
  xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
22
22
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
@@ -32,8 +32,8 @@ xax/task/loggers/stdout.py,sha256=BBXqr95gNt5KuCN8XyKnTJF8JdwkR4JgLKrkvcaTBVM,67
32
32
  xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
33
33
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
34
34
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
35
- xax/task/mixins/checkpointing.py,sha256=a6tVyISsDIz68rrhb1rAh3rjQlqkDVJCmSBmETQrnRM,8480
36
- xax/task/mixins/compile.py,sha256=8jEdlGs-a14N_CwZA3Rxe461MT83dyIDr3Z56VkjviQ,3693
35
+ xax/task/mixins/checkpointing.py,sha256=nRddgtasagf0oTZE9LE5IN5JY7jy4BD_M0rlqYp4sCM,8554
36
+ xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
37
37
  xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
38
38
  xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
39
39
  xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
@@ -41,10 +41,10 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
41
41
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
42
42
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
43
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
44
- xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
44
+ xax/task/mixins/train.py,sha256=1hmUx1HIL8HKfwOnupS3Knsw1CiK2YCbIQnUTYyDEms,26157
45
45
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
- xax/utils/debugging.py,sha256=0DU41DvYp3SZ9tMrM7sSFpfhC7dieMYR7eRlGNAFrdM,1783
47
- xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
46
+ xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
47
+ xax/utils/experiments.py,sha256=X6MESZ3z_Z0DLH6NQucuPzibuOc6rZmlf5UZt4in458,29591
48
48
  xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
49
49
  xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
50
50
  xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.1.13.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.1.13.dist-info/METADATA,sha256=P3N5hJMZtXPs199OHPN_cBp57S9zlOGB2B1TRPaczuI,1878
63
- xax-0.1.13.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
- xax-0.1.13.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.1.13.dist-info/RECORD,,
61
+ xax-0.1.15.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.1.15.dist-info/METADATA,sha256=i5thFSTL1Zx03UpnCj7f71rxSgs0P3L6ZDd6vYEtM7U,1878
63
+ xax-0.1.15.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
+ xax-0.1.15.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.1.15.dist-info/RECORD,,
File without changes