xax 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +15 -5
- xax/nn/geom.py +5 -1
- xax/nn/metrics.py +92 -0
- xax/task/mixins/train.py +1 -1
- xax/utils/pytree.py +10 -0
- {xax-0.2.14.dist-info → xax-0.2.16.dist-info}/METADATA +1 -1
- {xax-0.2.14.dist-info → xax-0.2.16.dist-info}/RECORD +10 -10
- {xax-0.2.14.dist-info → xax-0.2.16.dist-info}/WHEEL +1 -1
- xax/nn/norm.py +0 -24
- {xax-0.2.14.dist-info → xax-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.14.dist-info → xax-0.2.16.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.2.
|
15
|
+
__version__ = "0.2.16"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -51,6 +51,7 @@ __all__ = [
|
|
51
51
|
"rotation_matrix_to_rotation6d",
|
52
52
|
"cross_entropy",
|
53
53
|
"cast_norm_type",
|
54
|
+
"dynamic_time_warping",
|
54
55
|
"get_norm",
|
55
56
|
"is_master",
|
56
57
|
"BaseSSMBlock",
|
@@ -136,6 +137,7 @@ __all__ = [
|
|
136
137
|
"reshuffle_pytree_independently",
|
137
138
|
"slice_array",
|
138
139
|
"slice_pytree",
|
140
|
+
"tuple_insert",
|
139
141
|
"update_pytree",
|
140
142
|
"TextBlock",
|
141
143
|
"camelcase_to_snakecase",
|
@@ -229,8 +231,9 @@ NAME_MAP: dict[str, str] = {
|
|
229
231
|
"rotation6d_to_rotation_matrix": "nn.geom",
|
230
232
|
"rotation_matrix_to_rotation6d": "nn.geom",
|
231
233
|
"cross_entropy": "nn.losses",
|
232
|
-
"cast_norm_type": "nn.
|
233
|
-
"
|
234
|
+
"cast_norm_type": "nn.metrics",
|
235
|
+
"dynamic_time_warping": "nn.metrics",
|
236
|
+
"get_norm": "nn.metrics",
|
234
237
|
"is_master": "nn.parallel",
|
235
238
|
"BaseSSMBlock": "nn.ssm",
|
236
239
|
"DiagSSMBlock": "nn.ssm",
|
@@ -315,6 +318,7 @@ NAME_MAP: dict[str, str] = {
|
|
315
318
|
"reshuffle_pytree_independently": "utils.pytree",
|
316
319
|
"slice_array": "utils.pytree",
|
317
320
|
"slice_pytree": "utils.pytree",
|
321
|
+
"tuple_insert": "utils.pytree",
|
318
322
|
"update_pytree": "utils.pytree",
|
319
323
|
"TextBlock": "utils.text",
|
320
324
|
"camelcase_to_snakecase": "utils.text",
|
@@ -345,7 +349,7 @@ NAME_MAP.update(
|
|
345
349
|
"LOG_ERROR_SUMMARY": "utils.logging",
|
346
350
|
"LOG_PING": "utils.logging",
|
347
351
|
"LOG_STATUS": "utils.logging",
|
348
|
-
"NormType": "nn.
|
352
|
+
"NormType": "nn.metrics",
|
349
353
|
"Output": "task.mixins.output",
|
350
354
|
"Phase": "core.state",
|
351
355
|
"RawConfigType": "task.base",
|
@@ -410,7 +414,12 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
410
414
|
rotation_matrix_to_rotation6d,
|
411
415
|
)
|
412
416
|
from xax.nn.losses import cross_entropy
|
413
|
-
from xax.nn.
|
417
|
+
from xax.nn.metrics import (
|
418
|
+
NormType,
|
419
|
+
cast_norm_type,
|
420
|
+
dynamic_time_warping,
|
421
|
+
get_norm,
|
422
|
+
)
|
414
423
|
from xax.nn.parallel import is_master
|
415
424
|
from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
|
416
425
|
from xax.task.base import RawConfigType
|
@@ -495,6 +504,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
495
504
|
reshuffle_pytree_independently,
|
496
505
|
slice_array,
|
497
506
|
slice_pytree,
|
507
|
+
tuple_insert,
|
498
508
|
update_pytree,
|
499
509
|
)
|
500
510
|
from xax.utils.text import (
|
xax/nn/geom.py
CHANGED
@@ -102,12 +102,13 @@ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Ar
|
|
102
102
|
return jnp.concatenate([gx, gy, -gz], axis=-1)
|
103
103
|
|
104
104
|
|
105
|
-
def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Array:
|
105
|
+
def rotate_vector_by_quat(vector: Array, quat: Array, inverse: bool = False, eps: float = 1e-6) -> Array:
|
106
106
|
"""Rotates a vector by a quaternion.
|
107
107
|
|
108
108
|
Args:
|
109
109
|
vector: The vector to rotate, shape (*, 3).
|
110
110
|
quat: The quaternion to rotate by, shape (*, 4).
|
111
|
+
inverse: If True, rotate the vector by the conjugate of the quaternion.
|
111
112
|
eps: A small epsilon value to avoid division by zero.
|
112
113
|
|
113
114
|
Returns:
|
@@ -117,6 +118,9 @@ def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Arra
|
|
117
118
|
quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
|
118
119
|
w, x, y, z = jnp.split(quat, 4, axis=-1)
|
119
120
|
|
121
|
+
if inverse:
|
122
|
+
x, y, z = -x, -y, -z
|
123
|
+
|
120
124
|
# Extract vector components
|
121
125
|
vx, vy, vz = jnp.split(vector, 3, axis=-1)
|
122
126
|
|
xax/nn/metrics.py
ADDED
@@ -0,0 +1,92 @@
|
|
1
|
+
"""Norm and metric utilities."""
|
2
|
+
|
3
|
+
from typing import Literal, cast, get_args, overload
|
4
|
+
|
5
|
+
import chex
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
from jaxtyping import Array
|
9
|
+
|
10
|
+
from xax.utils.jax import jit as xax_jit
|
11
|
+
|
12
|
+
# Accepted values for the `norm` argument of `get_norm`: "l1" selects the
# elementwise absolute value, "l2" selects the elementwise square.
NormType = Literal["l1", "l2"]
|
13
|
+
|
14
|
+
|
15
|
+
def cast_norm_type(norm: str) -> NormType:
    """Validates a string and narrows it to ``NormType``.

    Args:
        norm: The candidate norm name.

    Returns:
        The same string, typed as ``NormType``.

    Raises:
        ValueError: If ``norm`` is not one of the values listed in ``NormType``.
    """
    allowed = get_args(NormType)
    if norm in allowed:
        return cast(NormType, norm)
    raise ValueError(f"Invalid norm: {norm}")
|
19
|
+
|
20
|
+
|
21
|
+
def get_norm(x: Array, norm: NormType) -> Array:
    """Applies the elementwise penalty selected by ``norm``.

    No reduction is performed; the output has the same shape as ``x``.

    Args:
        x: The input array.
        norm: ``"l1"`` for ``|x|`` or ``"l2"`` for ``x ** 2`` (squared, not
            square-rooted).

    Returns:
        The elementwise absolute value or square of ``x``.

    Raises:
        ValueError: If ``norm`` is not ``"l1"`` or ``"l2"``.
    """
    if norm == "l1":
        return jnp.abs(x)
    if norm == "l2":
        return jnp.square(x)
    raise ValueError(f"Invalid norm: {norm}")
|
29
|
+
|
30
|
+
|
31
|
+
@overload
def dynamic_time_warping(distance_matrix_nm: Array) -> Array: ...


@overload
def dynamic_time_warping(distance_matrix_nm: Array, return_path: Literal[False]) -> Array: ...


@overload
def dynamic_time_warping(distance_matrix_nm: Array, return_path: Literal[True]) -> tuple[Array, Array]: ...


@xax_jit(static_argnames=["return_path"])
def dynamic_time_warping(distance_matrix_nm: Array, return_path: bool = False) -> Array | tuple[Array, Array]:
    """Dynamic Time Warping.

    Finds the cheapest monotone path through the distance matrix that starts at
    ``(0, 0)`` and ends at ``(N - 1, M - 1)``. On each step to the next column
    the path either stays on the same row or moves down exactly one row, so
    every column is matched to exactly one row (hence ``N <= M``).

    Args:
        distance_matrix_nm: A matrix of pairwise distances between two
            sequences, with shape (N, M), with the condition that N <= M.
        return_path: If set, return the minimum path, otherwise just return
            the cost. The latter is preferred if using this function as a
            distance metric since it avoids the backwards scan on backpointers.

    Returns:
        The scalar cost of the minimum path from the top-left corner of the
        distance matrix to the bottom-right corner. If ``return_path`` is set,
        a tuple ``(cost, path)`` where ``path`` has shape (M,) and ``path[j]``
        is the row index matched to column ``j``.

    Raises:
        ValueError: If N > M.
    """
    chex.assert_shape(distance_matrix_nm, (None, None))
    n, m = distance_matrix_nm.shape

    # Shapes are static under jit, so this check runs at trace time.
    if n > m:
        raise ValueError(f"Invalid dynamic time warping distance matrix shape: ({n}, {m})")

    # The row index can grow by at most one per column, so cells with
    # row > column can never be reached; give them infinite cost.
    row_idx = jnp.arange(n)[:, None]
    col_idx = jnp.arange(m)[None, :]
    distance_matrix_nm = jnp.where(row_idx > col_idx, jnp.inf, distance_matrix_nm)
    indices = jnp.arange(n)

    # Sweep the columns left to right, carrying the cumulative cost per row.
    def scan_fn(prev_cost: Array, cur_distances: Array) -> tuple[Array, Array | None]:
        same_trans = prev_cost
        # Cost of arriving from the previous row; row 0 has no predecessor.
        prev_trans = jnp.pad(prev_cost[:-1], ((1, 0),), mode="constant", constant_values=jnp.inf)
        nc = jnp.minimum(prev_trans, same_trans) + cur_distances
        if not return_path:
            # Emit nothing so scan does not stack an unused (M - 1, N) array.
            return nc, None
        return nc, jnp.where(prev_trans < same_trans, indices - 1, indices)

    init_cost = distance_matrix_nm[:, 0]
    final_cost, back_pointers = jax.lax.scan(scan_fn, init_cost, distance_matrix_nm[:, 1:].T)

    if not return_path:
        # Cost of ending at the bottom-right cell (row N - 1, column M - 1).
        return final_cost[-1]

    # Scan the back pointers backwards to get the minimum path.
    def scan_back_fn(carry: Array, back_pointer: Array) -> tuple[Array, Array]:
        prev_idx = back_pointer[carry]
        return prev_idx, carry

    final_index = jnp.array(n - 1)
    _, min_path = jax.lax.scan(scan_back_fn, final_index, back_pointers, reverse=True)

    # The backwards scan yields rows for columns 1..M-1; column 0 is always row 0.
    min_path = jnp.pad(min_path, ((1, 0),), mode="constant", constant_values=0)

    return final_cost[-1], min_path
|
xax/task/mixins/train.py
CHANGED
@@ -363,7 +363,7 @@ class TrainMixin(
|
|
363
363
|
self,
|
364
364
|
key: PRNGKeyArray,
|
365
365
|
load_optimizer: Literal[True],
|
366
|
-
) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]: ...
|
366
|
+
) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State]: ...
|
367
367
|
|
368
368
|
def load_initial_state(
|
369
369
|
self,
|
xax/utils/pytree.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
"""Utils for accessing, modifying, and otherwise manipulating pytrees."""
|
2
2
|
|
3
|
+
from typing import TypeVar
|
4
|
+
|
3
5
|
import chex
|
4
6
|
import equinox as eqx
|
5
7
|
import jax
|
@@ -7,6 +9,8 @@ import jax.numpy as jnp
|
|
7
9
|
from jax import Array
|
8
10
|
from jaxtyping import PRNGKeyArray, PyTree
|
9
11
|
|
12
|
+
T = TypeVar("T")
|
13
|
+
|
10
14
|
|
11
15
|
def slice_array(x: Array, start: Array, slice_length: int) -> Array:
|
12
16
|
"""Get a slice of an array along the first dimension.
|
@@ -243,3 +247,9 @@ def get_pytree_param_count(pytree: PyTree) -> int:
|
|
243
247
|
"""Calculates the total number of parameters in a PyTree."""
|
244
248
|
leaves, _ = jax.tree.flatten(pytree)
|
245
249
|
return sum(x.size for x in leaves if isinstance(x, jnp.ndarray) and eqx.is_inexact_array(x))
|
250
|
+
|
251
|
+
|
252
|
+
def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
    """Returns a copy of ``t`` with the element at ``index`` replaced by ``value``.

    Despite the name, nothing is inserted: the result has the same length as
    ``t``. Indexing follows normal list-assignment rules, so negative indices
    count from the end and an out-of-range ``index`` raises ``IndexError``.
    """
    items = [*t]
    items[index] = value
    return (*items,)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=c8583PDbOsOtahDvF6sHzP8-VfJlM4M9Bo4G0OHdMmQ,15733
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
@@ -10,9 +10,9 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
|
10
10
|
xax/nn/equinox.py,sha256=JZuSApD4bL0UK5W1nrQtucWYvNWUha07J6LTLk_RX-Y,4910
|
11
11
|
xax/nn/export.py,sha256=pRfM2B4hB2EvljysC6AjtgB_7Cn7JtaP3dhYU2stZtY,5545
|
12
12
|
xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
|
13
|
-
xax/nn/geom.py,sha256=
|
13
|
+
xax/nn/geom.py,sha256=_jCCnUu6HihYVGQPD5f9gZO1tFLrp36NXEki56tk5Q8,7851
|
14
14
|
xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
|
15
|
-
xax/nn/
|
15
|
+
xax/nn/metrics.py,sha256=OAkeScwhi-wTBIJ59KHUhYbZTq4V4V-LG-mKlxMJ7bY,3238
|
16
16
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
17
17
|
xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
|
18
18
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
|
|
41
41
|
xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
|
42
42
|
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
|
-
xax/task/mixins/train.py,sha256=
|
44
|
+
xax/task/mixins/train.py,sha256=sUgZ7_WI4GUreYIDSICpU81IFJNJiHlP0VSv3QFvAB4,33483
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
46
|
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
47
47
|
xax/utils/experiments.py,sha256=bj8BftSHT3fFzfiJ0Co0WvqWo0rUS8kQnQYpVvH8FTM,29942
|
@@ -50,7 +50,7 @@ xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
|
|
50
50
|
xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
|
51
51
|
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
52
52
|
xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
|
53
|
-
xax/utils/pytree.py,sha256=
|
53
|
+
xax/utils/pytree.py,sha256=rVY2kKa637xfX3Oue6OP9ScwmDyxJ_CeHkUpZZtmN04,9231
|
54
54
|
xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
|
55
55
|
xax/utils/text.py,sha256=xS02aSzdywl3KIaNSpKWcxdd37oYlUJtu9wIjkc1wVc,10654
|
56
56
|
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.2.
|
62
|
-
xax-0.2.
|
63
|
-
xax-0.2.
|
64
|
-
xax-0.2.
|
65
|
-
xax-0.2.
|
61
|
+
xax-0.2.16.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.2.16.dist-info/METADATA,sha256=N0iLqcRSjmiCna089lGY2ij1lkwCv2wROuv-DcdG4pg,1880
|
63
|
+
xax-0.2.16.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
|
64
|
+
xax-0.2.16.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.2.16.dist-info/RECORD,,
|
xax/nn/norm.py
DELETED
@@ -1,24 +0,0 @@
|
|
1
|
-
"""Normalization utilities."""
|
2
|
-
|
3
|
-
from typing import Literal, cast, get_args
|
4
|
-
|
5
|
-
import jax.numpy as jnp
|
6
|
-
from jaxtyping import Array
|
7
|
-
|
8
|
-
NormType = Literal["l1", "l2"]
|
9
|
-
|
10
|
-
|
11
|
-
def cast_norm_type(norm: str) -> NormType:
|
12
|
-
if norm not in get_args(NormType):
|
13
|
-
raise ValueError(f"Invalid norm: {norm}")
|
14
|
-
return cast(NormType, norm)
|
15
|
-
|
16
|
-
|
17
|
-
def get_norm(x: Array, norm: NormType) -> Array:
|
18
|
-
match norm:
|
19
|
-
case "l1":
|
20
|
-
return jnp.abs(x)
|
21
|
-
case "l2":
|
22
|
-
return jnp.square(x)
|
23
|
-
case _:
|
24
|
-
raise ValueError(f"Invalid norm: {norm}")
|
File without changes
|
File without changes
|