xax 0.2.14__py3-none-any.whl → 0.2.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +15 -5
- xax/nn/metrics.py +92 -0
- xax/task/mixins/train.py +1 -1
- xax/utils/pytree.py +10 -0
- {xax-0.2.14.dist-info → xax-0.2.15.dist-info}/METADATA +1 -1
- {xax-0.2.14.dist-info → xax-0.2.15.dist-info}/RECORD +9 -9
- xax/nn/norm.py +0 -24
- {xax-0.2.14.dist-info → xax-0.2.15.dist-info}/WHEEL +0 -0
- {xax-0.2.14.dist-info → xax-0.2.15.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.14.dist-info → xax-0.2.15.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.2.14"
|
15
|
+
__version__ = "0.2.15"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -51,6 +51,7 @@ __all__ = [
|
|
51
51
|
"rotation_matrix_to_rotation6d",
|
52
52
|
"cross_entropy",
|
53
53
|
"cast_norm_type",
|
54
|
+
"dynamic_time_warping",
|
54
55
|
"get_norm",
|
55
56
|
"is_master",
|
56
57
|
"BaseSSMBlock",
|
@@ -136,6 +137,7 @@ __all__ = [
|
|
136
137
|
"reshuffle_pytree_independently",
|
137
138
|
"slice_array",
|
138
139
|
"slice_pytree",
|
140
|
+
"tuple_insert",
|
139
141
|
"update_pytree",
|
140
142
|
"TextBlock",
|
141
143
|
"camelcase_to_snakecase",
|
@@ -229,8 +231,9 @@ NAME_MAP: dict[str, str] = {
|
|
229
231
|
"rotation6d_to_rotation_matrix": "nn.geom",
|
230
232
|
"rotation_matrix_to_rotation6d": "nn.geom",
|
231
233
|
"cross_entropy": "nn.losses",
|
232
|
-
"cast_norm_type": "nn.
|
233
|
-
"
|
234
|
+
"cast_norm_type": "nn.metrics",
|
235
|
+
"dynamic_time_warping": "nn.metrics",
|
236
|
+
"get_norm": "nn.metrics",
|
234
237
|
"is_master": "nn.parallel",
|
235
238
|
"BaseSSMBlock": "nn.ssm",
|
236
239
|
"DiagSSMBlock": "nn.ssm",
|
@@ -315,6 +318,7 @@ NAME_MAP: dict[str, str] = {
|
|
315
318
|
"reshuffle_pytree_independently": "utils.pytree",
|
316
319
|
"slice_array": "utils.pytree",
|
317
320
|
"slice_pytree": "utils.pytree",
|
321
|
+
"tuple_insert": "utils.pytree",
|
318
322
|
"update_pytree": "utils.pytree",
|
319
323
|
"TextBlock": "utils.text",
|
320
324
|
"camelcase_to_snakecase": "utils.text",
|
@@ -345,7 +349,7 @@ NAME_MAP.update(
|
|
345
349
|
"LOG_ERROR_SUMMARY": "utils.logging",
|
346
350
|
"LOG_PING": "utils.logging",
|
347
351
|
"LOG_STATUS": "utils.logging",
|
348
|
-
"NormType": "nn.
|
352
|
+
"NormType": "nn.metrics",
|
349
353
|
"Output": "task.mixins.output",
|
350
354
|
"Phase": "core.state",
|
351
355
|
"RawConfigType": "task.base",
|
@@ -410,7 +414,12 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
410
414
|
rotation_matrix_to_rotation6d,
|
411
415
|
)
|
412
416
|
from xax.nn.losses import cross_entropy
|
413
|
-
from xax.nn.norm import NormType, cast_norm_type, get_norm
|
417
|
+
from xax.nn.metrics import (
|
418
|
+
NormType,
|
419
|
+
cast_norm_type,
|
420
|
+
dynamic_time_warping,
|
421
|
+
get_norm,
|
422
|
+
)
|
414
423
|
from xax.nn.parallel import is_master
|
415
424
|
from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
|
416
425
|
from xax.task.base import RawConfigType
|
@@ -495,6 +504,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
495
504
|
reshuffle_pytree_independently,
|
496
505
|
slice_array,
|
497
506
|
slice_pytree,
|
507
|
+
tuple_insert,
|
498
508
|
update_pytree,
|
499
509
|
)
|
500
510
|
from xax.utils.text import (
|
xax/nn/metrics.py
ADDED
@@ -0,0 +1,92 @@
|
|
1
|
+
"""Norm and metric utilities."""
|
2
|
+
|
3
|
+
from typing import Literal, cast, get_args, overload
|
4
|
+
|
5
|
+
import chex
|
6
|
+
import jax
|
7
|
+
import jax.numpy as jnp
|
8
|
+
from jaxtyping import Array
|
9
|
+
|
10
|
+
from xax.utils.jax import jit as xax_jit
|
11
|
+
|
12
|
+
# Names of the supported norms; validation below is derived from this alias.
NormType = Literal["l1", "l2"]


def cast_norm_type(norm: str) -> NormType:
    """Validate a norm name and narrow it to ``NormType``.

    Args:
        norm: Candidate norm name.

    Returns:
        The same string, typed as ``NormType``.

    Raises:
        ValueError: If ``norm`` is not one of the ``NormType`` literals.
    """
    valid_norms = get_args(NormType)
    if norm in valid_norms:
        return cast(NormType, norm)
    raise ValueError(f"Invalid norm: {norm}")
|
19
|
+
|
20
|
+
|
21
|
+
def get_norm(x: Array, norm: NormType) -> Array:
    """Apply the element-wise penalty associated with a norm.

    Args:
        x: Input array.
        norm: Which norm to apply, ``"l1"`` or ``"l2"``.

    Returns:
        ``jnp.abs(x)`` for ``"l1"`` and ``jnp.square(x)`` for ``"l2"``; no
        reduction is performed.

    Raises:
        ValueError: If ``norm`` is any other value.
    """
    if norm == "l1":
        return jnp.abs(x)
    if norm == "l2":
        return jnp.square(x)
    raise ValueError(f"Invalid norm: {norm}")
|
29
|
+
|
30
|
+
|
31
|
+
@overload
def dynamic_time_warping(distance_matrix_nm: Array, return_path: Literal[False] = ...) -> Array: ...


@overload
def dynamic_time_warping(distance_matrix_nm: Array, return_path: Literal[True]) -> tuple[Array, Array]: ...


@overload
def dynamic_time_warping(distance_matrix_nm: Array, return_path: bool = ...) -> Array | tuple[Array, Array]: ...


@xax_jit(static_argnames=["return_path"])
def dynamic_time_warping(distance_matrix_nm: Array, return_path: bool = False) -> Array | tuple[Array, Array]:
    """Dynamic Time Warping.

    Each of the M columns is assigned to one of the N rows. The assigned row
    starts at 0 and, from one column to the next, either stays the same or
    advances by one.

    Args:
        distance_matrix_nm: A matrix of pairwise distances between two
            sequences, with shape (N, M), with the condition that N <= M.
        return_path: If set, return the minimum path, otherwise just return
            the cost. The latter is preferred if using this function as a
            distance metric since it avoids the backwards scan on backpointers.

    Returns:
        If ``return_path`` is False, the accumulated costs in the last column,
        with shape (N,): entry ``i`` is the cost of the cheapest path from the
        top-left corner to row ``i`` of the last column, so the last entry is
        the cost of reaching the bottom-right corner. If ``return_path`` is
        True, a tuple of that bottom-right cost and the minimum path, with
        shape (M,), giving the row index used at each column.

    Raises:
        AssertionError: If N > M (only while assertions are enabled).
    """
    # Only two-dimensional inputs are accepted.
    chex.assert_shape(distance_matrix_nm, (None, None))
    n, m = distance_matrix_nm.shape

    assert n <= m, f"Invalid dynamic time warping distance matrix shape: ({n}, {m})"

    # Masks values which cannot be reached: row i cannot be reached before
    # column i because the row advances by at most one per column.
    row_idx = jnp.arange(n)[:, None]
    col_idx = jnp.arange(m)[None, :]
    mask = row_idx > col_idx
    distance_matrix_nm = jnp.where(mask, jnp.inf, distance_matrix_nm)

    # Pre-pads with inf. The extra leading row is dropped again by the `[1:]`
    # slices below.
    distance_matrix_nm = jnp.pad(distance_matrix_nm, ((1, 0), (0, 0)), mode="constant", constant_values=jnp.inf)
    indices = jnp.arange(n)

    # Scan over the remaining columns to fill the cost matrix; the carry holds
    # the accumulated cost of every row for the previous column.
    def scan_fn(prev_cost: Array, cur_distances: Array) -> tuple[Array, Array]:
        # Cost of staying on the same row.
        same_trans = prev_cost
        # Cost of arriving from the row above; row 0 has no row above it.
        prev_trans = jnp.pad(prev_cost[:-1], ((1, 0),), mode="constant", constant_values=jnp.inf)
        nc = jnp.minimum(prev_trans, same_trans) + cur_distances[1:]
        # The second output is the per-row back pointer when a path is
        # requested; otherwise it is a placeholder that is never read.
        return nc, jnp.where(prev_trans < same_trans, indices - 1, indices) if return_path else nc

    # The first column seeds the scan; after masking only row 0 is finite.
    init_cost = distance_matrix_nm[1:, 0]
    final_cost, back_pointers = jax.lax.scan(scan_fn, init_cost, distance_matrix_nm[:, 1:].T)

    if not return_path:
        # NOTE(review): this returns the whole (N,) cost vector, whereas the
        # path branch below returns only `final_cost[-1]` - confirm intended.
        return final_cost

    # Scan the back pointers backwards to get the minimum path.
    def scan_back_fn(carry: Array, back_pointer: Array) -> tuple[Array, Array]:
        # `carry` is the row at the current column; look up its predecessor.
        prev_idx = back_pointer[carry]
        return prev_idx, carry

    final_index = jnp.array(n - 1)
    _, min_path = jax.lax.scan(scan_back_fn, final_index, back_pointers, reverse=True)
    # The back pointers cover columns 1..M-1; column 0 always uses row 0.
    min_path = jnp.pad(min_path, ((1, 0)), mode="constant", constant_values=0)

    return final_cost[-1], min_path
|
xax/task/mixins/train.py
CHANGED
@@ -363,7 +363,7 @@ class TrainMixin(
|
|
363
363
|
self,
|
364
364
|
key: PRNGKeyArray,
|
365
365
|
load_optimizer: Literal[True],
|
366
|
-
) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]: ...
|
366
|
+
) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State]: ...
|
367
367
|
|
368
368
|
def load_initial_state(
|
369
369
|
self,
|
xax/utils/pytree.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
"""Utils for accessing, modifying, and otherwise manipulating pytrees."""
|
2
2
|
|
3
|
+
from typing import TypeVar
|
4
|
+
|
3
5
|
import chex
|
4
6
|
import equinox as eqx
|
5
7
|
import jax
|
@@ -7,6 +9,8 @@ import jax.numpy as jnp
|
|
7
9
|
from jax import Array
|
8
10
|
from jaxtyping import PRNGKeyArray, PyTree
|
9
11
|
|
12
|
+
T = TypeVar("T")
|
13
|
+
|
10
14
|
|
11
15
|
def slice_array(x: Array, start: Array, slice_length: int) -> Array:
|
12
16
|
"""Get a slice of an array along the first dimension.
|
@@ -243,3 +247,9 @@ def get_pytree_param_count(pytree: PyTree) -> int:
|
|
243
247
|
"""Calculates the total number of parameters in a PyTree."""
|
244
248
|
leaves, _ = jax.tree.flatten(pytree)
|
245
249
|
return sum(x.size for x in leaves if isinstance(x, jnp.ndarray) and eqx.is_inexact_array(x))
|
250
|
+
|
251
|
+
|
252
|
+
def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
    """Return a copy of a tuple with one element replaced.

    Despite the name, the tuple does not grow: the element at ``index`` is
    overwritten with ``value``.

    Args:
        t: The source tuple, which is left untouched.
        index: Position to overwrite; follows list indexing rules, so
            negative indices count from the end.
        value: The new element.

    Returns:
        A new tuple of the same length as ``t``.

    Raises:
        IndexError: If ``index`` is out of range for ``t``.
    """
    items = [*t]
    items[index] = value
    return (*items,)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=JVxuGfbwBPHXiF4kSG0Pb73mzu3EIaRipjvt0Y-Z9W4,15733
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
@@ -12,7 +12,7 @@ xax/nn/export.py,sha256=pRfM2B4hB2EvljysC6AjtgB_7Cn7JtaP3dhYU2stZtY,5545
|
|
12
12
|
xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
|
13
13
|
xax/nn/geom.py,sha256=B8QE-L-xJWhf9KygTByPUAWe7Clpek4GlTABpsJFMBs,7702
|
14
14
|
xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
|
15
|
-
xax/nn/
|
15
|
+
xax/nn/metrics.py,sha256=OAkeScwhi-wTBIJ59KHUhYbZTq4V4V-LG-mKlxMJ7bY,3238
|
16
16
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
17
17
|
xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
|
18
18
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
|
|
41
41
|
xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
|
42
42
|
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
|
-
xax/task/mixins/train.py,sha256=
|
44
|
+
xax/task/mixins/train.py,sha256=sUgZ7_WI4GUreYIDSICpU81IFJNJiHlP0VSv3QFvAB4,33483
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
46
|
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
47
47
|
xax/utils/experiments.py,sha256=bj8BftSHT3fFzfiJ0Co0WvqWo0rUS8kQnQYpVvH8FTM,29942
|
@@ -50,7 +50,7 @@ xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
|
|
50
50
|
xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
|
51
51
|
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
52
52
|
xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
|
53
|
-
xax/utils/pytree.py,sha256=
|
53
|
+
xax/utils/pytree.py,sha256=rVY2kKa637xfX3Oue6OP9ScwmDyxJ_CeHkUpZZtmN04,9231
|
54
54
|
xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
|
55
55
|
xax/utils/text.py,sha256=xS02aSzdywl3KIaNSpKWcxdd37oYlUJtu9wIjkc1wVc,10654
|
56
56
|
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.2.
|
62
|
-
xax-0.2.
|
63
|
-
xax-0.2.
|
64
|
-
xax-0.2.
|
65
|
-
xax-0.2.
|
61
|
+
xax-0.2.15.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.2.15.dist-info/METADATA,sha256=6LJoiKOyNmF1MJSwVSdbEJATzSv1P77Amn4ZJCbWaP0,1880
|
63
|
+
xax-0.2.15.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
|
64
|
+
xax-0.2.15.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.2.15.dist-info/RECORD,,
|
xax/nn/norm.py
DELETED
@@ -1,24 +0,0 @@
|
|
1
|
-
"""Normalization utilities."""
|
2
|
-
|
3
|
-
from typing import Literal, cast, get_args
|
4
|
-
|
5
|
-
import jax.numpy as jnp
|
6
|
-
from jaxtyping import Array
|
7
|
-
|
8
|
-
# Names of the supported norms; validation below is derived from this alias.
NormType = Literal["l1", "l2"]


def cast_norm_type(norm: str) -> NormType:
    """Validate a norm name and narrow it to ``NormType``.

    Args:
        norm: Candidate norm name.

    Returns:
        The same string, typed as ``NormType``.

    Raises:
        ValueError: If ``norm`` is not one of the ``NormType`` literals.
    """
    valid_norms = get_args(NormType)
    if norm in valid_norms:
        return cast(NormType, norm)
    raise ValueError(f"Invalid norm: {norm}")
|
15
|
-
|
16
|
-
|
17
|
-
def get_norm(x: Array, norm: NormType) -> Array:
    """Apply the element-wise penalty associated with a norm.

    Args:
        x: Input array.
        norm: Which norm to apply, ``"l1"`` or ``"l2"``.

    Returns:
        ``jnp.abs(x)`` for ``"l1"`` and ``jnp.square(x)`` for ``"l2"``; no
        reduction is performed.

    Raises:
        ValueError: If ``norm`` is any other value.
    """
    if norm == "l1":
        return jnp.abs(x)
    if norm == "l2":
        return jnp.square(x)
    raise ValueError(f"Invalid norm: {norm}")
|
File without changes
|
File without changes
|
File without changes
|