xax 0.2.14__tar.gz → 0.2.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {xax-0.2.14/xax.egg-info → xax-0.2.16}/PKG-INFO +1 -1
  2. {xax-0.2.14 → xax-0.2.16}/xax/__init__.py +15 -5
  3. {xax-0.2.14 → xax-0.2.16}/xax/nn/geom.py +5 -1
  4. xax-0.2.16/xax/nn/metrics.py +92 -0
  5. {xax-0.2.14 → xax-0.2.16}/xax/task/mixins/train.py +1 -1
  6. {xax-0.2.14 → xax-0.2.16}/xax/utils/pytree.py +10 -0
  7. {xax-0.2.14 → xax-0.2.16/xax.egg-info}/PKG-INFO +1 -1
  8. {xax-0.2.14 → xax-0.2.16}/xax.egg-info/SOURCES.txt +1 -1
  9. xax-0.2.14/xax/nn/norm.py +0 -24
  10. {xax-0.2.14 → xax-0.2.16}/LICENSE +0 -0
  11. {xax-0.2.14 → xax-0.2.16}/MANIFEST.in +0 -0
  12. {xax-0.2.14 → xax-0.2.16}/README.md +0 -0
  13. {xax-0.2.14 → xax-0.2.16}/pyproject.toml +0 -0
  14. {xax-0.2.14 → xax-0.2.16}/setup.cfg +0 -0
  15. {xax-0.2.14 → xax-0.2.16}/setup.py +0 -0
  16. {xax-0.2.14 → xax-0.2.16}/xax/core/__init__.py +0 -0
  17. {xax-0.2.14 → xax-0.2.16}/xax/core/conf.py +0 -0
  18. {xax-0.2.14 → xax-0.2.16}/xax/core/state.py +0 -0
  19. {xax-0.2.14 → xax-0.2.16}/xax/nn/__init__.py +0 -0
  20. {xax-0.2.14 → xax-0.2.16}/xax/nn/embeddings.py +0 -0
  21. {xax-0.2.14 → xax-0.2.16}/xax/nn/equinox.py +0 -0
  22. {xax-0.2.14 → xax-0.2.16}/xax/nn/export.py +0 -0
  23. {xax-0.2.14 → xax-0.2.16}/xax/nn/functions.py +0 -0
  24. {xax-0.2.14 → xax-0.2.16}/xax/nn/losses.py +0 -0
  25. {xax-0.2.14 → xax-0.2.16}/xax/nn/parallel.py +0 -0
  26. {xax-0.2.14 → xax-0.2.16}/xax/nn/ssm.py +0 -0
  27. {xax-0.2.14 → xax-0.2.16}/xax/py.typed +0 -0
  28. {xax-0.2.14 → xax-0.2.16}/xax/requirements-dev.txt +0 -0
  29. {xax-0.2.14 → xax-0.2.16}/xax/requirements.txt +0 -0
  30. {xax-0.2.14 → xax-0.2.16}/xax/task/__init__.py +0 -0
  31. {xax-0.2.14 → xax-0.2.16}/xax/task/base.py +0 -0
  32. {xax-0.2.14 → xax-0.2.16}/xax/task/launchers/__init__.py +0 -0
  33. {xax-0.2.14 → xax-0.2.16}/xax/task/launchers/base.py +0 -0
  34. {xax-0.2.14 → xax-0.2.16}/xax/task/launchers/cli.py +0 -0
  35. {xax-0.2.14 → xax-0.2.16}/xax/task/launchers/single_process.py +0 -0
  36. {xax-0.2.14 → xax-0.2.16}/xax/task/logger.py +0 -0
  37. {xax-0.2.14 → xax-0.2.16}/xax/task/loggers/__init__.py +0 -0
  38. {xax-0.2.14 → xax-0.2.16}/xax/task/loggers/callback.py +0 -0
  39. {xax-0.2.14 → xax-0.2.16}/xax/task/loggers/json.py +0 -0
  40. {xax-0.2.14 → xax-0.2.16}/xax/task/loggers/state.py +0 -0
  41. {xax-0.2.14 → xax-0.2.16}/xax/task/loggers/stdout.py +0 -0
  42. {xax-0.2.14 → xax-0.2.16}/xax/task/loggers/tensorboard.py +0 -0
  43. {xax-0.2.14 → xax-0.2.16}/xax/task/mixins/__init__.py +0 -0
  44. {xax-0.2.14 → xax-0.2.16}/xax/task/mixins/artifacts.py +0 -0
  45. {xax-0.2.14 → xax-0.2.16}/xax/task/mixins/checkpointing.py +0 -0
  46. {xax-0.2.14 → xax-0.2.16}/xax/task/mixins/compile.py +0 -0
  47. {xax-0.2.14 → xax-0.2.16}/xax/task/mixins/cpu_stats.py +0 -0
  48. {xax-0.2.14 → xax-0.2.16}/xax/task/mixins/data_loader.py +0 -0
  49. {xax-0.2.14 → xax-0.2.16}/xax/task/mixins/gpu_stats.py +0 -0
  50. {xax-0.2.14 → xax-0.2.16}/xax/task/mixins/logger.py +0 -0
  51. {xax-0.2.14 → xax-0.2.16}/xax/task/mixins/process.py +0 -0
  52. {xax-0.2.14 → xax-0.2.16}/xax/task/mixins/runnable.py +0 -0
  53. {xax-0.2.14 → xax-0.2.16}/xax/task/mixins/step_wrapper.py +0 -0
  54. {xax-0.2.14 → xax-0.2.16}/xax/task/script.py +0 -0
  55. {xax-0.2.14 → xax-0.2.16}/xax/task/task.py +0 -0
  56. {xax-0.2.14 → xax-0.2.16}/xax/utils/__init__.py +0 -0
  57. {xax-0.2.14 → xax-0.2.16}/xax/utils/data/__init__.py +0 -0
  58. {xax-0.2.14 → xax-0.2.16}/xax/utils/data/collate.py +0 -0
  59. {xax-0.2.14 → xax-0.2.16}/xax/utils/debugging.py +0 -0
  60. {xax-0.2.14 → xax-0.2.16}/xax/utils/experiments.py +0 -0
  61. {xax-0.2.14 → xax-0.2.16}/xax/utils/jax.py +0 -0
  62. {xax-0.2.14 → xax-0.2.16}/xax/utils/jaxpr.py +0 -0
  63. {xax-0.2.14 → xax-0.2.16}/xax/utils/logging.py +0 -0
  64. {xax-0.2.14 → xax-0.2.16}/xax/utils/numpy.py +0 -0
  65. {xax-0.2.14 → xax-0.2.16}/xax/utils/profile.py +0 -0
  66. {xax-0.2.14 → xax-0.2.16}/xax/utils/tensorboard.py +0 -0
  67. {xax-0.2.14 → xax-0.2.16}/xax/utils/text.py +0 -0
  68. {xax-0.2.14 → xax-0.2.16}/xax/utils/types/__init__.py +0 -0
  69. {xax-0.2.14 → xax-0.2.16}/xax/utils/types/frozen_dict.py +0 -0
  70. {xax-0.2.14 → xax-0.2.16}/xax/utils/types/hashable_array.py +0 -0
  71. {xax-0.2.14 → xax-0.2.16}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.2.14 → xax-0.2.16}/xax.egg-info/requires.txt +0 -0
  73. {xax-0.2.14 → xax-0.2.16}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.14
3
+ Version: 0.2.16
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.14"
15
+ __version__ = "0.2.16"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -51,6 +51,7 @@ __all__ = [
51
51
  "rotation_matrix_to_rotation6d",
52
52
  "cross_entropy",
53
53
  "cast_norm_type",
54
+ "dynamic_time_warping",
54
55
  "get_norm",
55
56
  "is_master",
56
57
  "BaseSSMBlock",
@@ -136,6 +137,7 @@ __all__ = [
136
137
  "reshuffle_pytree_independently",
137
138
  "slice_array",
138
139
  "slice_pytree",
140
+ "tuple_insert",
139
141
  "update_pytree",
140
142
  "TextBlock",
141
143
  "camelcase_to_snakecase",
@@ -229,8 +231,9 @@ NAME_MAP: dict[str, str] = {
229
231
  "rotation6d_to_rotation_matrix": "nn.geom",
230
232
  "rotation_matrix_to_rotation6d": "nn.geom",
231
233
  "cross_entropy": "nn.losses",
232
- "cast_norm_type": "nn.norm",
233
- "get_norm": "nn.norm",
234
+ "cast_norm_type": "nn.metrics",
235
+ "dynamic_time_warping": "nn.metrics",
236
+ "get_norm": "nn.metrics",
234
237
  "is_master": "nn.parallel",
235
238
  "BaseSSMBlock": "nn.ssm",
236
239
  "DiagSSMBlock": "nn.ssm",
@@ -315,6 +318,7 @@ NAME_MAP: dict[str, str] = {
315
318
  "reshuffle_pytree_independently": "utils.pytree",
316
319
  "slice_array": "utils.pytree",
317
320
  "slice_pytree": "utils.pytree",
321
+ "tuple_insert": "utils.pytree",
318
322
  "update_pytree": "utils.pytree",
319
323
  "TextBlock": "utils.text",
320
324
  "camelcase_to_snakecase": "utils.text",
@@ -345,7 +349,7 @@ NAME_MAP.update(
345
349
  "LOG_ERROR_SUMMARY": "utils.logging",
346
350
  "LOG_PING": "utils.logging",
347
351
  "LOG_STATUS": "utils.logging",
348
- "NormType": "nn.norm",
352
+ "NormType": "nn.metrics",
349
353
  "Output": "task.mixins.output",
350
354
  "Phase": "core.state",
351
355
  "RawConfigType": "task.base",
@@ -410,7 +414,12 @@ if IMPORT_ALL or TYPE_CHECKING:
410
414
  rotation_matrix_to_rotation6d,
411
415
  )
412
416
  from xax.nn.losses import cross_entropy
413
- from xax.nn.norm import NormType, cast_norm_type, get_norm
417
+ from xax.nn.metrics import (
418
+ NormType,
419
+ cast_norm_type,
420
+ dynamic_time_warping,
421
+ get_norm,
422
+ )
414
423
  from xax.nn.parallel import is_master
415
424
  from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
416
425
  from xax.task.base import RawConfigType
@@ -495,6 +504,7 @@ if IMPORT_ALL or TYPE_CHECKING:
495
504
  reshuffle_pytree_independently,
496
505
  slice_array,
497
506
  slice_pytree,
507
+ tuple_insert,
498
508
  update_pytree,
499
509
  )
500
510
  from xax.utils.text import (
@@ -102,12 +102,13 @@ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Ar
102
102
  return jnp.concatenate([gx, gy, -gz], axis=-1)
103
103
 
104
104
 
105
- def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Array:
105
+ def rotate_vector_by_quat(vector: Array, quat: Array, inverse: bool = False, eps: float = 1e-6) -> Array:
106
106
  """Rotates a vector by a quaternion.
107
107
 
108
108
  Args:
109
109
  vector: The vector to rotate, shape (*, 3).
110
110
  quat: The quaternion to rotate by, shape (*, 4).
111
+ inverse: If True, rotate the vector by the conjugate of the quaternion.
111
112
  eps: A small epsilon value to avoid division by zero.
112
113
 
113
114
  Returns:
@@ -117,6 +118,9 @@ def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Arra
117
118
  quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
118
119
  w, x, y, z = jnp.split(quat, 4, axis=-1)
119
120
 
121
+ if inverse:
122
+ x, y, z = -x, -y, -z
123
+
120
124
  # Extract vector components
121
125
  vx, vy, vz = jnp.split(vector, 3, axis=-1)
122
126
 
@@ -0,0 +1,92 @@
1
+ """Norm and metric utilities."""
2
+
3
+ from typing import Literal, cast, get_args, overload
4
+
5
+ import chex
6
+ import jax
7
+ import jax.numpy as jnp
8
+ from jaxtyping import Array
9
+
10
+ from xax.utils.jax import jit as xax_jit
11
+
12
# Names accepted by `get_norm`: "l1" -> element-wise absolute value,
# "l2" -> element-wise square.
NormType = Literal["l1", "l2"]


def cast_norm_type(norm: str) -> NormType:
    """Validates a string and narrows its static type to ``NormType``.

    Args:
        norm: Candidate norm name.

    Returns:
        The same string, typed as ``NormType``.

    Raises:
        ValueError: If ``norm`` is not one of the ``NormType`` literals.
    """
    allowed = get_args(NormType)
    if norm in allowed:
        return cast(NormType, norm)
    raise ValueError(f"Invalid norm: {norm}")


def get_norm(x: Array, norm: NormType) -> Array:
    """Applies the selected norm to ``x`` element-wise.

    Neither option reduces over any axis, so the result has the same shape
    as ``x``. Note that ``"l2"`` gives the element-wise square; no square
    root is taken.

    Args:
        x: Input array.
        norm: Which norm to apply, ``"l1"`` or ``"l2"``.

    Returns:
        ``abs(x)`` for ``"l1"`` or ``x ** 2`` for ``"l2"``.

    Raises:
        ValueError: If ``norm`` is not a recognized norm name.
    """
    if norm == "l1":
        return jnp.abs(x)
    if norm == "l2":
        return jnp.square(x)
    raise ValueError(f"Invalid norm: {norm}")
29
+
30
+
31
@overload
def dynamic_time_warping(distance_matrix_nm: Array, return_path: Literal[False] = False) -> Array: ...


@overload
def dynamic_time_warping(distance_matrix_nm: Array, return_path: Literal[True]) -> tuple[Array, Array]: ...


@xax_jit(static_argnames=["return_path"])
def dynamic_time_warping(distance_matrix_nm: Array, return_path: bool = False) -> Array | tuple[Array, Array]:
    """Dynamic Time Warping.

    Computes the minimum-cost path from the top-left corner ``(0, 0)`` of the
    distance matrix to the bottom-right corner ``(N - 1, M - 1)``. Each step
    advances exactly one column and either stays on the same row or advances
    one row. Every column is therefore visited exactly once, which is why
    ``N <= M`` is required.

    Args:
        distance_matrix_nm: A matrix of pairwise distances between two
            sequences, with shape (N, M), with the condition that
            0 < N <= M.
        return_path: If set, return the minimum path as well as the cost;
            otherwise just return the cost. The latter is preferred if
            using this function as a distance metric since it avoids
            storing and scanning the back-pointers.

    Returns:
        If ``return_path`` is False, the scalar cost of the minimum path.
        Otherwise a tuple ``(cost, path)``, where ``path`` has shape (M,)
        and ``path[j]`` is the row index visited at column ``j``.
    """
    chex.assert_shape(distance_matrix_nm, (None, None))
    n, m = distance_matrix_nm.shape

    assert 0 < n <= m, f"Invalid dynamic time warping distance matrix shape: ({n}, {m})"

    # The row index can advance by at most one per column, so cell (i, j)
    # with i > j is unreachable; give those cells infinite cost.
    row_idx = jnp.arange(n)[:, None]
    col_idx = jnp.arange(m)[None, :]
    distance_matrix_nm = jnp.where(row_idx > col_idx, jnp.inf, distance_matrix_nm)
    indices = jnp.arange(n)

    # One scan step per column: cost of reaching row i in this column from
    # row i (same) or row i - 1 (prev) in the previous column.
    def scan_fn(prev_cost: Array, cur_distances: Array) -> tuple[Array, Array | None]:
        same_trans = prev_cost
        prev_trans = jnp.pad(prev_cost[:-1], ((1, 0),), mode="constant", constant_values=jnp.inf)
        next_cost = jnp.minimum(prev_trans, same_trans) + cur_distances
        if not return_path:
            # Emit nothing per step so scan does not materialize an (M - 1, N) buffer.
            return next_cost, None
        # Back-pointer to the previous row; ties prefer staying on the same row.
        return next_cost, jnp.where(prev_trans < same_trans, indices - 1, indices)

    # In column 0 only row 0 is reachable (other rows were masked to inf).
    init_cost = distance_matrix_nm[:, 0]
    final_cost, back_pointers = jax.lax.scan(scan_fn, init_cost, distance_matrix_nm[:, 1:].T)

    if not return_path:
        # Cost of reaching the bottom-right corner, consistent with the path branch.
        return final_cost[-1]

    # Walk the back-pointers from the last column to the first. With
    # reverse=True each output is stored at its forward (column) position.
    def scan_back_fn(carry: Array, back_pointer: Array) -> tuple[Array, Array]:
        prev_idx = back_pointer[carry]
        return prev_idx, carry

    final_index = jnp.array(n - 1)
    _, min_path = jax.lax.scan(scan_back_fn, final_index, back_pointers, reverse=True)
    # The scan yields rows for columns 1..M-1; the path always starts at row 0.
    min_path = jnp.pad(min_path, (1, 0), mode="constant", constant_values=0)

    return final_cost[-1], min_path
@@ -363,7 +363,7 @@ class TrainMixin(
363
363
  self,
364
364
  key: PRNGKeyArray,
365
365
  load_optimizer: Literal[True],
366
- ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]: ...
366
+ ) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State]: ...
367
367
 
368
368
  def load_initial_state(
369
369
  self,
@@ -1,5 +1,7 @@
1
1
  """Utils for accessing, modifying, and otherwise manipulating pytrees."""
2
2
 
3
+ from typing import TypeVar
4
+
3
5
  import chex
4
6
  import equinox as eqx
5
7
  import jax
@@ -7,6 +9,8 @@ import jax.numpy as jnp
7
9
  from jax import Array
8
10
  from jaxtyping import PRNGKeyArray, PyTree
9
11
 
12
# Element type shared by the input tuple, the inserted value, and the
# returned tuple in `tuple_insert`.
T = TypeVar("T")
13
+
10
14
 
11
15
  def slice_array(x: Array, start: Array, slice_length: int) -> Array:
12
16
  """Get a slice of an array along the first dimension.
@@ -243,3 +247,9 @@ def get_pytree_param_count(pytree: PyTree) -> int:
243
247
  """Calculates the total number of parameters in a PyTree."""
244
248
  leaves, _ = jax.tree.flatten(pytree)
245
249
  return sum(x.size for x in leaves if isinstance(x, jnp.ndarray) and eqx.is_inexact_array(x))
250
+
251
+
252
def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
    """Returns a copy of ``t`` with the item at ``index`` replaced by ``value``.

    Despite the name, nothing is inserted: the result has the same length as
    ``t``. Negative indices count from the end, as with list indexing.

    Args:
        t: Source tuple; it is not modified.
        index: Position of the item to replace.
        value: Replacement item.

    Returns:
        A new tuple equal to ``t`` except at ``index``.

    Raises:
        IndexError: If ``index`` is out of range for ``t``.
    """
    items = [*t]
    items[index] = value
    return (*items,)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.14
3
+ Version: 0.2.16
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -23,7 +23,7 @@ xax/nn/export.py
23
23
  xax/nn/functions.py
24
24
  xax/nn/geom.py
25
25
  xax/nn/losses.py
26
- xax/nn/norm.py
26
+ xax/nn/metrics.py
27
27
  xax/nn/parallel.py
28
28
  xax/nn/ssm.py
29
29
  xax/task/__init__.py
xax-0.2.14/xax/nn/norm.py DELETED
@@ -1,24 +0,0 @@
1
- """Normalization utilities."""
2
-
3
- from typing import Literal, cast, get_args
4
-
5
- import jax.numpy as jnp
6
- from jaxtyping import Array
7
-
8
- NormType = Literal["l1", "l2"]
9
-
10
-
11
- def cast_norm_type(norm: str) -> NormType:
12
- if norm not in get_args(NormType):
13
- raise ValueError(f"Invalid norm: {norm}")
14
- return cast(NormType, norm)
15
-
16
-
17
- def get_norm(x: Array, norm: NormType) -> Array:
18
- match norm:
19
- case "l1":
20
- return jnp.abs(x)
21
- case "l2":
22
- return jnp.square(x)
23
- case _:
24
- raise ValueError(f"Invalid norm: {norm}")
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes