xax 0.2.14__tar.gz → 0.2.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {xax-0.2.14/xax.egg-info → xax-0.2.15}/PKG-INFO +1 -1
  2. {xax-0.2.14 → xax-0.2.15}/xax/__init__.py +15 -5
  3. xax-0.2.15/xax/nn/metrics.py +92 -0
  4. {xax-0.2.14 → xax-0.2.15}/xax/task/mixins/train.py +1 -1
  5. {xax-0.2.14 → xax-0.2.15}/xax/utils/pytree.py +10 -0
  6. {xax-0.2.14 → xax-0.2.15/xax.egg-info}/PKG-INFO +1 -1
  7. {xax-0.2.14 → xax-0.2.15}/xax.egg-info/SOURCES.txt +1 -1
  8. xax-0.2.14/xax/nn/norm.py +0 -24
  9. {xax-0.2.14 → xax-0.2.15}/LICENSE +0 -0
  10. {xax-0.2.14 → xax-0.2.15}/MANIFEST.in +0 -0
  11. {xax-0.2.14 → xax-0.2.15}/README.md +0 -0
  12. {xax-0.2.14 → xax-0.2.15}/pyproject.toml +0 -0
  13. {xax-0.2.14 → xax-0.2.15}/setup.cfg +0 -0
  14. {xax-0.2.14 → xax-0.2.15}/setup.py +0 -0
  15. {xax-0.2.14 → xax-0.2.15}/xax/core/__init__.py +0 -0
  16. {xax-0.2.14 → xax-0.2.15}/xax/core/conf.py +0 -0
  17. {xax-0.2.14 → xax-0.2.15}/xax/core/state.py +0 -0
  18. {xax-0.2.14 → xax-0.2.15}/xax/nn/__init__.py +0 -0
  19. {xax-0.2.14 → xax-0.2.15}/xax/nn/embeddings.py +0 -0
  20. {xax-0.2.14 → xax-0.2.15}/xax/nn/equinox.py +0 -0
  21. {xax-0.2.14 → xax-0.2.15}/xax/nn/export.py +0 -0
  22. {xax-0.2.14 → xax-0.2.15}/xax/nn/functions.py +0 -0
  23. {xax-0.2.14 → xax-0.2.15}/xax/nn/geom.py +0 -0
  24. {xax-0.2.14 → xax-0.2.15}/xax/nn/losses.py +0 -0
  25. {xax-0.2.14 → xax-0.2.15}/xax/nn/parallel.py +0 -0
  26. {xax-0.2.14 → xax-0.2.15}/xax/nn/ssm.py +0 -0
  27. {xax-0.2.14 → xax-0.2.15}/xax/py.typed +0 -0
  28. {xax-0.2.14 → xax-0.2.15}/xax/requirements-dev.txt +0 -0
  29. {xax-0.2.14 → xax-0.2.15}/xax/requirements.txt +0 -0
  30. {xax-0.2.14 → xax-0.2.15}/xax/task/__init__.py +0 -0
  31. {xax-0.2.14 → xax-0.2.15}/xax/task/base.py +0 -0
  32. {xax-0.2.14 → xax-0.2.15}/xax/task/launchers/__init__.py +0 -0
  33. {xax-0.2.14 → xax-0.2.15}/xax/task/launchers/base.py +0 -0
  34. {xax-0.2.14 → xax-0.2.15}/xax/task/launchers/cli.py +0 -0
  35. {xax-0.2.14 → xax-0.2.15}/xax/task/launchers/single_process.py +0 -0
  36. {xax-0.2.14 → xax-0.2.15}/xax/task/logger.py +0 -0
  37. {xax-0.2.14 → xax-0.2.15}/xax/task/loggers/__init__.py +0 -0
  38. {xax-0.2.14 → xax-0.2.15}/xax/task/loggers/callback.py +0 -0
  39. {xax-0.2.14 → xax-0.2.15}/xax/task/loggers/json.py +0 -0
  40. {xax-0.2.14 → xax-0.2.15}/xax/task/loggers/state.py +0 -0
  41. {xax-0.2.14 → xax-0.2.15}/xax/task/loggers/stdout.py +0 -0
  42. {xax-0.2.14 → xax-0.2.15}/xax/task/loggers/tensorboard.py +0 -0
  43. {xax-0.2.14 → xax-0.2.15}/xax/task/mixins/__init__.py +0 -0
  44. {xax-0.2.14 → xax-0.2.15}/xax/task/mixins/artifacts.py +0 -0
  45. {xax-0.2.14 → xax-0.2.15}/xax/task/mixins/checkpointing.py +0 -0
  46. {xax-0.2.14 → xax-0.2.15}/xax/task/mixins/compile.py +0 -0
  47. {xax-0.2.14 → xax-0.2.15}/xax/task/mixins/cpu_stats.py +0 -0
  48. {xax-0.2.14 → xax-0.2.15}/xax/task/mixins/data_loader.py +0 -0
  49. {xax-0.2.14 → xax-0.2.15}/xax/task/mixins/gpu_stats.py +0 -0
  50. {xax-0.2.14 → xax-0.2.15}/xax/task/mixins/logger.py +0 -0
  51. {xax-0.2.14 → xax-0.2.15}/xax/task/mixins/process.py +0 -0
  52. {xax-0.2.14 → xax-0.2.15}/xax/task/mixins/runnable.py +0 -0
  53. {xax-0.2.14 → xax-0.2.15}/xax/task/mixins/step_wrapper.py +0 -0
  54. {xax-0.2.14 → xax-0.2.15}/xax/task/script.py +0 -0
  55. {xax-0.2.14 → xax-0.2.15}/xax/task/task.py +0 -0
  56. {xax-0.2.14 → xax-0.2.15}/xax/utils/__init__.py +0 -0
  57. {xax-0.2.14 → xax-0.2.15}/xax/utils/data/__init__.py +0 -0
  58. {xax-0.2.14 → xax-0.2.15}/xax/utils/data/collate.py +0 -0
  59. {xax-0.2.14 → xax-0.2.15}/xax/utils/debugging.py +0 -0
  60. {xax-0.2.14 → xax-0.2.15}/xax/utils/experiments.py +0 -0
  61. {xax-0.2.14 → xax-0.2.15}/xax/utils/jax.py +0 -0
  62. {xax-0.2.14 → xax-0.2.15}/xax/utils/jaxpr.py +0 -0
  63. {xax-0.2.14 → xax-0.2.15}/xax/utils/logging.py +0 -0
  64. {xax-0.2.14 → xax-0.2.15}/xax/utils/numpy.py +0 -0
  65. {xax-0.2.14 → xax-0.2.15}/xax/utils/profile.py +0 -0
  66. {xax-0.2.14 → xax-0.2.15}/xax/utils/tensorboard.py +0 -0
  67. {xax-0.2.14 → xax-0.2.15}/xax/utils/text.py +0 -0
  68. {xax-0.2.14 → xax-0.2.15}/xax/utils/types/__init__.py +0 -0
  69. {xax-0.2.14 → xax-0.2.15}/xax/utils/types/frozen_dict.py +0 -0
  70. {xax-0.2.14 → xax-0.2.15}/xax/utils/types/hashable_array.py +0 -0
  71. {xax-0.2.14 → xax-0.2.15}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.2.14 → xax-0.2.15}/xax.egg-info/requires.txt +0 -0
  73. {xax-0.2.14 → xax-0.2.15}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.14
3
+ Version: 0.2.15
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.14"
15
+ __version__ = "0.2.15"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -51,6 +51,7 @@ __all__ = [
51
51
  "rotation_matrix_to_rotation6d",
52
52
  "cross_entropy",
53
53
  "cast_norm_type",
54
+ "dynamic_time_warping",
54
55
  "get_norm",
55
56
  "is_master",
56
57
  "BaseSSMBlock",
@@ -136,6 +137,7 @@ __all__ = [
136
137
  "reshuffle_pytree_independently",
137
138
  "slice_array",
138
139
  "slice_pytree",
140
+ "tuple_insert",
139
141
  "update_pytree",
140
142
  "TextBlock",
141
143
  "camelcase_to_snakecase",
@@ -229,8 +231,9 @@ NAME_MAP: dict[str, str] = {
229
231
  "rotation6d_to_rotation_matrix": "nn.geom",
230
232
  "rotation_matrix_to_rotation6d": "nn.geom",
231
233
  "cross_entropy": "nn.losses",
232
- "cast_norm_type": "nn.norm",
233
- "get_norm": "nn.norm",
234
+ "cast_norm_type": "nn.metrics",
235
+ "dynamic_time_warping": "nn.metrics",
236
+ "get_norm": "nn.metrics",
234
237
  "is_master": "nn.parallel",
235
238
  "BaseSSMBlock": "nn.ssm",
236
239
  "DiagSSMBlock": "nn.ssm",
@@ -315,6 +318,7 @@ NAME_MAP: dict[str, str] = {
315
318
  "reshuffle_pytree_independently": "utils.pytree",
316
319
  "slice_array": "utils.pytree",
317
320
  "slice_pytree": "utils.pytree",
321
+ "tuple_insert": "utils.pytree",
318
322
  "update_pytree": "utils.pytree",
319
323
  "TextBlock": "utils.text",
320
324
  "camelcase_to_snakecase": "utils.text",
@@ -345,7 +349,7 @@ NAME_MAP.update(
345
349
  "LOG_ERROR_SUMMARY": "utils.logging",
346
350
  "LOG_PING": "utils.logging",
347
351
  "LOG_STATUS": "utils.logging",
348
- "NormType": "nn.norm",
352
+ "NormType": "nn.metrics",
349
353
  "Output": "task.mixins.output",
350
354
  "Phase": "core.state",
351
355
  "RawConfigType": "task.base",
@@ -410,7 +414,12 @@ if IMPORT_ALL or TYPE_CHECKING:
410
414
  rotation_matrix_to_rotation6d,
411
415
  )
412
416
  from xax.nn.losses import cross_entropy
413
- from xax.nn.norm import NormType, cast_norm_type, get_norm
417
+ from xax.nn.metrics import (
418
+ NormType,
419
+ cast_norm_type,
420
+ dynamic_time_warping,
421
+ get_norm,
422
+ )
414
423
  from xax.nn.parallel import is_master
415
424
  from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
416
425
  from xax.task.base import RawConfigType
@@ -495,6 +504,7 @@ if IMPORT_ALL or TYPE_CHECKING:
495
504
  reshuffle_pytree_independently,
496
505
  slice_array,
497
506
  slice_pytree,
507
+ tuple_insert,
498
508
  update_pytree,
499
509
  )
500
510
  from xax.utils.text import (
@@ -0,0 +1,92 @@
1
+ """Norm and metric utilities."""
2
+
3
+ from typing import Literal, cast, get_args, overload
4
+
5
+ import chex
6
+ import jax
7
+ import jax.numpy as jnp
8
+ from jaxtyping import Array
9
+
10
+ from xax.utils.jax import jit as xax_jit
11
+
12
+ NormType = Literal["l1", "l2"]
13
+
14
+
15
+ def cast_norm_type(norm: str) -> NormType:
16
+ if norm not in get_args(NormType):
17
+ raise ValueError(f"Invalid norm: {norm}")
18
+ return cast(NormType, norm)
19
+
20
+
21
+ def get_norm(x: Array, norm: NormType) -> Array:
22
+ match norm:
23
+ case "l1":
24
+ return jnp.abs(x)
25
+ case "l2":
26
+ return jnp.square(x)
27
+ case _:
28
+ raise ValueError(f"Invalid norm: {norm}")
29
+
30
+
31
+ @overload
32
+ def dynamic_time_warping(distance_matrix_nm: Array) -> Array: ...
33
+
34
+
35
+ @overload
36
+ def dynamic_time_warping(distance_matrix_nm: Array, return_path: Literal[True]) -> tuple[Array, Array]: ...
37
+
38
+
39
+ @xax_jit(static_argnames=["return_path"])
40
+ def dynamic_time_warping(distance_matrix_nm: Array, return_path: bool = False) -> Array | tuple[Array, Array]:
41
+ """Dynamic Time Warping.
42
+
43
+ Args:
44
+ distance_matrix_nm: A matrix of pairwise distances between two
45
+ sequences, with shape (N, M), with the condition that N <= M.
46
+ return_path: If set, return the minimum path, otherwise just return
47
+ the cost. The latter is preferred if using this function as a
48
+ distance metric since it avoids the backwards scan on backpointers.
49
+
50
+ Returns:
51
+ The cost of the minimum path from the top-left corner of the distance
52
+ matrix to the bottom-right corner, along with the indices of that
53
+ minimum path.
54
+ """
55
+ chex.assert_shape(distance_matrix_nm, (None, None))
56
+ n, m = distance_matrix_nm.shape
57
+
58
+ assert n <= m, f"Invalid dynamic time warping distance matrix shape: ({n}, {m})"
59
+
60
+ # Masks values which cannot be reached.
61
+ row_idx = jnp.arange(n)[:, None]
62
+ col_idx = jnp.arange(m)[None, :]
63
+ mask = row_idx > col_idx
64
+ distance_matrix_nm = jnp.where(mask, jnp.inf, distance_matrix_nm)
65
+
66
+ # Pre-pads with inf
67
+ distance_matrix_nm = jnp.pad(distance_matrix_nm, ((1, 0), (0, 0)), mode="constant", constant_values=jnp.inf)
68
+ indices = jnp.arange(n)
69
+
70
+ # Scan over remaining rows to fill cost matrix
71
+ def scan_fn(prev_cost: Array, cur_distances: Array) -> tuple[Array, Array]:
72
+ same_trans = prev_cost
73
+ prev_trans = jnp.pad(prev_cost[:-1], ((1, 0),), mode="constant", constant_values=jnp.inf)
74
+ nc = jnp.minimum(prev_trans, same_trans) + cur_distances[1:]
75
+ return nc, jnp.where(prev_trans < same_trans, indices - 1, indices) if return_path else nc
76
+
77
+ init_cost = distance_matrix_nm[1:, 0]
78
+ final_cost, back_pointers = jax.lax.scan(scan_fn, init_cost, distance_matrix_nm[:, 1:].T)
79
+
80
+ if not return_path:
81
+ return final_cost
82
+
83
+ # Scan the back pointers backwards to get the minimum path.
84
+ def scan_back_fn(carry: Array, back_pointer: Array) -> tuple[Array, Array]:
85
+ prev_idx = back_pointer[carry]
86
+ return prev_idx, carry
87
+
88
+ final_index = jnp.array(n - 1)
89
+ _, min_path = jax.lax.scan(scan_back_fn, final_index, back_pointers, reverse=True)
90
+ min_path = jnp.pad(min_path, ((1, 0)), mode="constant", constant_values=0)
91
+
92
+ return final_cost[-1], min_path
@@ -363,7 +363,7 @@ class TrainMixin(
363
363
  self,
364
364
  key: PRNGKeyArray,
365
365
  load_optimizer: Literal[True],
366
- ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]: ...
366
+ ) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State]: ...
367
367
 
368
368
  def load_initial_state(
369
369
  self,
@@ -1,5 +1,7 @@
1
1
  """Utils for accessing, modifying, and otherwise manipulating pytrees."""
2
2
 
3
+ from typing import TypeVar
4
+
3
5
  import chex
4
6
  import equinox as eqx
5
7
  import jax
@@ -7,6 +9,8 @@ import jax.numpy as jnp
7
9
  from jax import Array
8
10
  from jaxtyping import PRNGKeyArray, PyTree
9
11
 
12
+ T = TypeVar("T")
13
+
10
14
 
11
15
  def slice_array(x: Array, start: Array, slice_length: int) -> Array:
12
16
  """Get a slice of an array along the first dimension.
@@ -243,3 +247,9 @@ def get_pytree_param_count(pytree: PyTree) -> int:
243
247
  """Calculates the total number of parameters in a PyTree."""
244
248
  leaves, _ = jax.tree.flatten(pytree)
245
249
  return sum(x.size for x in leaves if isinstance(x, jnp.ndarray) and eqx.is_inexact_array(x))
250
+
251
+
252
+ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
253
+ mut = list(t)
254
+ mut[index] = value
255
+ return tuple(mut)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.14
3
+ Version: 0.2.15
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -23,7 +23,7 @@ xax/nn/export.py
23
23
  xax/nn/functions.py
24
24
  xax/nn/geom.py
25
25
  xax/nn/losses.py
26
- xax/nn/norm.py
26
+ xax/nn/metrics.py
27
27
  xax/nn/parallel.py
28
28
  xax/nn/ssm.py
29
29
  xax/task/__init__.py
xax-0.2.14/xax/nn/norm.py DELETED
@@ -1,24 +0,0 @@
1
- """Normalization utilities."""
2
-
3
- from typing import Literal, cast, get_args
4
-
5
- import jax.numpy as jnp
6
- from jaxtyping import Array
7
-
8
- NormType = Literal["l1", "l2"]
9
-
10
-
11
- def cast_norm_type(norm: str) -> NormType:
12
- if norm not in get_args(NormType):
13
- raise ValueError(f"Invalid norm: {norm}")
14
- return cast(NormType, norm)
15
-
16
-
17
- def get_norm(x: Array, norm: NormType) -> Array:
18
- match norm:
19
- case "l1":
20
- return jnp.abs(x)
21
- case "l2":
22
- return jnp.square(x)
23
- case _:
24
- raise ValueError(f"Invalid norm: {norm}")
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes