xax 0.1.14__tar.gz → 0.1.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.1.14/xax.egg-info → xax-0.1.15}/PKG-INFO +1 -1
  2. {xax-0.1.14 → xax-0.1.15}/xax/__init__.py +4 -1
  3. {xax-0.1.14 → xax-0.1.15}/xax/nn/geom.py +26 -5
  4. {xax-0.1.14 → xax-0.1.15}/xax/task/mixins/train.py +3 -5
  5. {xax-0.1.14 → xax-0.1.15}/xax/utils/experiments.py +14 -0
  6. {xax-0.1.14 → xax-0.1.15/xax.egg-info}/PKG-INFO +1 -1
  7. {xax-0.1.14 → xax-0.1.15}/LICENSE +0 -0
  8. {xax-0.1.14 → xax-0.1.15}/MANIFEST.in +0 -0
  9. {xax-0.1.14 → xax-0.1.15}/README.md +0 -0
  10. {xax-0.1.14 → xax-0.1.15}/pyproject.toml +0 -0
  11. {xax-0.1.14 → xax-0.1.15}/setup.cfg +0 -0
  12. {xax-0.1.14 → xax-0.1.15}/setup.py +0 -0
  13. {xax-0.1.14 → xax-0.1.15}/xax/core/__init__.py +0 -0
  14. {xax-0.1.14 → xax-0.1.15}/xax/core/conf.py +0 -0
  15. {xax-0.1.14 → xax-0.1.15}/xax/core/state.py +0 -0
  16. {xax-0.1.14 → xax-0.1.15}/xax/nn/__init__.py +0 -0
  17. {xax-0.1.14 → xax-0.1.15}/xax/nn/embeddings.py +0 -0
  18. {xax-0.1.14 → xax-0.1.15}/xax/nn/equinox.py +0 -0
  19. {xax-0.1.14 → xax-0.1.15}/xax/nn/export.py +0 -0
  20. {xax-0.1.14 → xax-0.1.15}/xax/nn/functions.py +0 -0
  21. {xax-0.1.14 → xax-0.1.15}/xax/nn/losses.py +0 -0
  22. {xax-0.1.14 → xax-0.1.15}/xax/nn/norm.py +0 -0
  23. {xax-0.1.14 → xax-0.1.15}/xax/nn/parallel.py +0 -0
  24. {xax-0.1.14 → xax-0.1.15}/xax/nn/ssm.py +0 -0
  25. {xax-0.1.14 → xax-0.1.15}/xax/py.typed +0 -0
  26. {xax-0.1.14 → xax-0.1.15}/xax/requirements-dev.txt +0 -0
  27. {xax-0.1.14 → xax-0.1.15}/xax/requirements.txt +0 -0
  28. {xax-0.1.14 → xax-0.1.15}/xax/task/__init__.py +0 -0
  29. {xax-0.1.14 → xax-0.1.15}/xax/task/base.py +0 -0
  30. {xax-0.1.14 → xax-0.1.15}/xax/task/launchers/__init__.py +0 -0
  31. {xax-0.1.14 → xax-0.1.15}/xax/task/launchers/base.py +0 -0
  32. {xax-0.1.14 → xax-0.1.15}/xax/task/launchers/cli.py +0 -0
  33. {xax-0.1.14 → xax-0.1.15}/xax/task/launchers/single_process.py +0 -0
  34. {xax-0.1.14 → xax-0.1.15}/xax/task/logger.py +0 -0
  35. {xax-0.1.14 → xax-0.1.15}/xax/task/loggers/__init__.py +0 -0
  36. {xax-0.1.14 → xax-0.1.15}/xax/task/loggers/callback.py +0 -0
  37. {xax-0.1.14 → xax-0.1.15}/xax/task/loggers/json.py +0 -0
  38. {xax-0.1.14 → xax-0.1.15}/xax/task/loggers/state.py +0 -0
  39. {xax-0.1.14 → xax-0.1.15}/xax/task/loggers/stdout.py +0 -0
  40. {xax-0.1.14 → xax-0.1.15}/xax/task/loggers/tensorboard.py +0 -0
  41. {xax-0.1.14 → xax-0.1.15}/xax/task/mixins/__init__.py +0 -0
  42. {xax-0.1.14 → xax-0.1.15}/xax/task/mixins/artifacts.py +0 -0
  43. {xax-0.1.14 → xax-0.1.15}/xax/task/mixins/checkpointing.py +0 -0
  44. {xax-0.1.14 → xax-0.1.15}/xax/task/mixins/compile.py +0 -0
  45. {xax-0.1.14 → xax-0.1.15}/xax/task/mixins/cpu_stats.py +0 -0
  46. {xax-0.1.14 → xax-0.1.15}/xax/task/mixins/data_loader.py +0 -0
  47. {xax-0.1.14 → xax-0.1.15}/xax/task/mixins/gpu_stats.py +0 -0
  48. {xax-0.1.14 → xax-0.1.15}/xax/task/mixins/logger.py +0 -0
  49. {xax-0.1.14 → xax-0.1.15}/xax/task/mixins/process.py +0 -0
  50. {xax-0.1.14 → xax-0.1.15}/xax/task/mixins/runnable.py +0 -0
  51. {xax-0.1.14 → xax-0.1.15}/xax/task/mixins/step_wrapper.py +0 -0
  52. {xax-0.1.14 → xax-0.1.15}/xax/task/script.py +0 -0
  53. {xax-0.1.14 → xax-0.1.15}/xax/task/task.py +0 -0
  54. {xax-0.1.14 → xax-0.1.15}/xax/utils/__init__.py +0 -0
  55. {xax-0.1.14 → xax-0.1.15}/xax/utils/data/__init__.py +0 -0
  56. {xax-0.1.14 → xax-0.1.15}/xax/utils/data/collate.py +0 -0
  57. {xax-0.1.14 → xax-0.1.15}/xax/utils/debugging.py +0 -0
  58. {xax-0.1.14 → xax-0.1.15}/xax/utils/jax.py +0 -0
  59. {xax-0.1.14 → xax-0.1.15}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.1.14 → xax-0.1.15}/xax/utils/logging.py +0 -0
  61. {xax-0.1.14 → xax-0.1.15}/xax/utils/numpy.py +0 -0
  62. {xax-0.1.14 → xax-0.1.15}/xax/utils/profile.py +0 -0
  63. {xax-0.1.14 → xax-0.1.15}/xax/utils/pytree.py +0 -0
  64. {xax-0.1.14 → xax-0.1.15}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.1.14 → xax-0.1.15}/xax/utils/text.py +0 -0
  66. {xax-0.1.14 → xax-0.1.15}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.1.14 → xax-0.1.15}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.1.14 → xax-0.1.15}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.1.14 → xax-0.1.15}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.1.14 → xax-0.1.15}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.1.14 → xax-0.1.15}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.1.14 → xax-0.1.15}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.14
3
+ Version: 0.1.15
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.14"
15
+ __version__ = "0.1.15"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -40,6 +40,7 @@ __all__ = [
40
40
  "load_eqx_mlp",
41
41
  "make_eqx_mlp",
42
42
  "save_eqx",
43
+ "cubic_bezier_interpolation",
43
44
  "euler_to_quat",
44
45
  "get_projected_gravity_vector_from_quat",
45
46
  "quat_to_euler",
@@ -201,6 +202,7 @@ NAME_MAP: dict[str, str] = {
201
202
  "load_eqx_mlp": "nn.equinox",
202
203
  "make_eqx_mlp": "nn.equinox",
203
204
  "save_eqx": "nn.equinox",
205
+ "cubic_bezier_interpolation": "nn.geom",
204
206
  "euler_to_quat": "nn.geom",
205
207
  "get_projected_gravity_vector_from_quat": "nn.geom",
206
208
  "quat_to_euler": "nn.geom",
@@ -363,6 +365,7 @@ if IMPORT_ALL or TYPE_CHECKING:
363
365
  save_eqx,
364
366
  )
365
367
  from xax.nn.geom import (
368
+ cubic_bezier_interpolation,
366
369
  euler_to_quat,
367
370
  get_projected_gravity_vector_from_quat,
368
371
  quat_to_euler,
@@ -1,10 +1,10 @@
1
1
  """Defines geometry functions."""
2
2
 
3
- import jax
4
3
  from jax import numpy as jnp
4
+ from jaxtyping import Array
5
5
 
6
6
 
7
- def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
7
+ def quat_to_euler(quat_4: Array, eps: float = 1e-6) -> Array:
8
8
  """Normalizes and converts a quaternion (w, x, y, z) to roll, pitch, yaw.
9
9
 
10
10
  Args:
@@ -40,7 +40,7 @@ def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
40
40
  return jnp.concatenate([roll, pitch, yaw], axis=-1)
41
41
 
42
42
 
43
- def euler_to_quat(euler_3: jax.Array) -> jax.Array:
43
+ def euler_to_quat(euler_3: Array) -> Array:
44
44
  """Converts roll, pitch, yaw angles to a quaternion (w, x, y, z).
45
45
 
46
46
  Args:
@@ -75,7 +75,7 @@ def euler_to_quat(euler_3: jax.Array) -> jax.Array:
75
75
  return quat
76
76
 
77
77
 
78
- def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -> jax.Array:
78
+ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Array:
79
79
  """Calculates the gravity vector projected onto the local frame given a quaternion orientation.
80
80
 
81
81
  Args:
@@ -101,7 +101,7 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -
101
101
  return jnp.concatenate([gx, gy, -gz], axis=-1)
102
102
 
103
103
 
104
- def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) -> jax.Array:
104
+ def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Array:
105
105
  """Rotates a vector by a quaternion.
106
106
 
107
107
  Args:
@@ -156,3 +156,24 @@ def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6)
156
156
  )
157
157
 
158
158
  return jnp.concatenate([xx, yy, zz], axis=-1)
159
+
160
+
161
+ def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
162
+ """Cubic bezier interpolation.
163
+
164
+ This is a cubic bezier curve that starts at y_start and ends at y_end,
165
+ and is controlled by the parameter x. The curve is defined by the following formula:
166
+
167
+ y(x) = y_start + (y_end - y_start) * (x**3 + 3 * (x**2 * (1 - x)))
168
+
169
+ Args:
170
+ y_start: The start value, shape (*).
171
+ y_end: The end value, shape (*).
172
+ x: The interpolation parameter, shape (*).
173
+
174
+ Returns:
175
+ The interpolated value, shape (*).
176
+ """
177
+ y_diff = y_end - y_start
178
+ bezier = x**3 + 3 * (x**2 * (1 - x))
179
+ return y_start + y_diff * bezier
@@ -50,8 +50,7 @@ from xax.utils.experiments import (
50
50
  TrainingFinishedError,
51
51
  diff_configs,
52
52
  get_diff_string,
53
- get_git_state,
54
- get_packages_with_versions,
53
+ get_state_file_string,
55
54
  get_training_code,
56
55
  )
57
56
  from xax.utils.jax import jit as xax_jit
@@ -534,9 +533,8 @@ class TrainMixin(
534
533
  logger.log(LOG_STATUS, self.task_path)
535
534
  logger.log(LOG_STATUS, self.task_name)
536
535
  logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
537
- self.logger.log_file("git_state.txt", get_git_state(self))
538
- self.logger.log_file("packages.txt", get_packages_with_versions())
539
- self.logger.log_file("training_code.txt", get_training_code(self))
536
+ self.logger.log_file("state.txt", get_state_file_string(self))
537
+ self.logger.log_file("training_code.py", get_training_code(self))
540
538
  self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
541
539
 
542
540
  def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
@@ -479,6 +479,20 @@ def get_packages_with_versions() -> str:
479
479
  return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
480
480
 
481
481
 
482
+ def get_command_line_string() -> str:
483
+ return " ".join(sys.argv)
484
+
485
+
486
+ def get_state_file_string(obj: object) -> str:
487
+ return "\n\n".join(
488
+ [
489
+ f"=== Command Line ===\n\n{get_command_line_string()}",
490
+ f"=== Git State ===\n\n{get_git_state(obj)}",
491
+ f"=== Packages ===\n\n{get_packages_with_versions()}",
492
+ ]
493
+ )
494
+
495
+
482
496
  def get_training_code(obj: object) -> str:
483
497
  """Gets the text from the file containing the provided object.
484
498
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.14
3
+ Version: 0.1.15
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes