xax 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +26 -8
- xax/nn/distributions.py +1 -2
- xax/nn/geom.py +42 -13
- xax/task/mixins/train.py +1 -1
- xax/utils/debugging.py +20 -4
- xax/utils/pytree.py +72 -10
- {xax-0.3.11.dist-info → xax-0.3.13.dist-info}/METADATA +1 -1
- {xax-0.3.11.dist-info → xax-0.3.13.dist-info}/RECORD +12 -12
- {xax-0.3.11.dist-info → xax-0.3.13.dist-info}/WHEEL +0 -0
- {xax-0.3.11.dist-info → xax-0.3.13.dist-info}/entry_points.txt +0 -0
- {xax-0.3.11.dist-info → xax-0.3.13.dist-info}/licenses/LICENSE +0 -0
- {xax-0.3.11.dist-info → xax-0.3.13.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.3.
|
15
|
+
__version__ = "0.3.13"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -53,6 +53,7 @@ __all__ = [
|
|
53
53
|
"quat_mul",
|
54
54
|
"quat_to_euler",
|
55
55
|
"quat_to_rotmat",
|
56
|
+
"quat_to_yaw",
|
56
57
|
"rotate_vector_by_quat",
|
57
58
|
"rotation6d_to_rotation_matrix",
|
58
59
|
"rotation_matrix_to_quat",
|
@@ -100,9 +101,9 @@ __all__ = [
|
|
100
101
|
"Task",
|
101
102
|
"collate",
|
102
103
|
"collate_non_null",
|
103
|
-
"
|
104
|
+
"breakpoint_if_nonfinite",
|
104
105
|
"get_named_leaves",
|
105
|
-
"
|
106
|
+
"log_if_nonfinite",
|
106
107
|
"BaseFileDownloader",
|
107
108
|
"ContextTimer",
|
108
109
|
"CumulativeTimer",
|
@@ -138,6 +139,7 @@ __all__ = [
|
|
138
139
|
"worker_chunk",
|
139
140
|
"profile",
|
140
141
|
"compute_nan_ratio",
|
142
|
+
"diff_pytree",
|
141
143
|
"flatten_array",
|
142
144
|
"flatten_pytree",
|
143
145
|
"get_pytree_mapping",
|
@@ -197,7 +199,10 @@ if "XLA_FLAGS" in os.environ:
|
|
197
199
|
# If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), disable
|
198
200
|
# Triton GEMM kernels. See https://github.com/NVIDIA/JAX-Toolbox
|
199
201
|
if shutil.which("nvidia-smi") is not None:
|
200
|
-
xla_flags += [
|
202
|
+
xla_flags += [
|
203
|
+
"--xla_gpu_enable_latency_hiding_scheduler=true",
|
204
|
+
"--xla_gpu_enable_triton_gemm=false",
|
205
|
+
]
|
201
206
|
os.environ["XLA_FLAGS"] = " ".join(xla_flags)
|
202
207
|
|
203
208
|
# If this flag is set, eagerly imports the entire package (not recommended).
|
@@ -245,6 +250,7 @@ NAME_MAP: dict[str, str] = {
|
|
245
250
|
"quat_mul": "nn.geom",
|
246
251
|
"quat_to_euler": "nn.geom",
|
247
252
|
"quat_to_rotmat": "nn.geom",
|
253
|
+
"quat_to_yaw": "nn.geom",
|
248
254
|
"rotate_vector_by_quat": "nn.geom",
|
249
255
|
"rotation6d_to_rotation_matrix": "nn.geom",
|
250
256
|
"rotation_matrix_to_quat": "nn.geom",
|
@@ -292,9 +298,9 @@ NAME_MAP: dict[str, str] = {
|
|
292
298
|
"Task": "task.task",
|
293
299
|
"collate": "utils.data.collate",
|
294
300
|
"collate_non_null": "utils.data.collate",
|
295
|
-
"
|
301
|
+
"breakpoint_if_nonfinite": "utils.debugging",
|
296
302
|
"get_named_leaves": "utils.debugging",
|
297
|
-
"
|
303
|
+
"log_if_nonfinite": "utils.debugging",
|
298
304
|
"BaseFileDownloader": "utils.experiments",
|
299
305
|
"ContextTimer": "utils.experiments",
|
300
306
|
"CumulativeTimer": "utils.experiments",
|
@@ -330,6 +336,7 @@ NAME_MAP: dict[str, str] = {
|
|
330
336
|
"worker_chunk": "utils.numpy",
|
331
337
|
"profile": "utils.profile",
|
332
338
|
"compute_nan_ratio": "utils.pytree",
|
339
|
+
"diff_pytree": "utils.pytree",
|
333
340
|
"flatten_array": "utils.pytree",
|
334
341
|
"flatten_pytree": "utils.pytree",
|
335
342
|
"get_pytree_mapping": "utils.pytree",
|
@@ -413,7 +420,12 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
413
420
|
TransformerCache,
|
414
421
|
TransformerStack,
|
415
422
|
)
|
416
|
-
from xax.nn.distributions import
|
423
|
+
from xax.nn.distributions import (
|
424
|
+
Categorical,
|
425
|
+
Distribution,
|
426
|
+
MixtureOfGaussians,
|
427
|
+
Normal,
|
428
|
+
)
|
417
429
|
from xax.nn.embeddings import (
|
418
430
|
EmbeddingKind,
|
419
431
|
FourierEmbeddings,
|
@@ -436,6 +448,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
436
448
|
quat_mul,
|
437
449
|
quat_to_euler,
|
438
450
|
quat_to_rotmat,
|
451
|
+
quat_to_yaw,
|
439
452
|
rotate_vector_by_quat,
|
440
453
|
rotation6d_to_rotation_matrix,
|
441
454
|
rotation_matrix_to_quat,
|
@@ -479,7 +492,11 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
479
492
|
from xax.task.script import Script, ScriptConfig
|
480
493
|
from xax.task.task import Config, Task
|
481
494
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
482
|
-
from xax.utils.debugging import
|
495
|
+
from xax.utils.debugging import (
|
496
|
+
breakpoint_if_nonfinite,
|
497
|
+
get_named_leaves,
|
498
|
+
log_if_nonfinite,
|
499
|
+
)
|
483
500
|
from xax.utils.experiments import (
|
484
501
|
BaseFileDownloader,
|
485
502
|
ContextTimer,
|
@@ -518,6 +535,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
518
535
|
from xax.utils.profile import profile
|
519
536
|
from xax.utils.pytree import (
|
520
537
|
compute_nan_ratio,
|
538
|
+
diff_pytree,
|
521
539
|
flatten_array,
|
522
540
|
flatten_pytree,
|
523
541
|
get_pytree_mapping,
|
xax/nn/distributions.py
CHANGED
@@ -12,7 +12,6 @@ __all__ = [
|
|
12
12
|
"MixtureOfGaussians",
|
13
13
|
]
|
14
14
|
|
15
|
-
import math
|
16
15
|
from abc import ABC, abstractmethod
|
17
16
|
|
18
17
|
import jax
|
@@ -20,7 +19,7 @@ import jax.numpy as jnp
|
|
20
19
|
from jaxtyping import Array, PRNGKeyArray
|
21
20
|
|
22
21
|
STD_CLIP = 1e-6
|
23
|
-
LOGIT_CLIP =
|
22
|
+
LOGIT_CLIP = 6.0
|
24
23
|
|
25
24
|
|
26
25
|
class Distribution(ABC):
|
xax/nn/geom.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""Defines geometry functions."""
|
2
2
|
|
3
3
|
import chex
|
4
|
+
import jax
|
4
5
|
from jax import numpy as jnp
|
5
6
|
from jaxtyping import Array
|
6
7
|
|
@@ -15,30 +16,53 @@ def quat_to_euler(quat_4: Array, eps: float = 1e-6) -> Array:
|
|
15
16
|
Returns:
|
16
17
|
The roll, pitch, yaw angles with shape (*, 3).
|
17
18
|
"""
|
18
|
-
|
19
|
-
|
19
|
+
# Normalize with clamping
|
20
|
+
norm_sq = jnp.sum(quat_4**2, axis=-1, keepdims=True)
|
21
|
+
inv_norm = jax.lax.rsqrt(jnp.maximum(norm_sq, eps))
|
22
|
+
quat_4 = quat_4 * inv_norm
|
23
|
+
|
24
|
+
w, x, y, z = jnp.unstack(quat_4, axis=-1)
|
20
25
|
|
21
26
|
# Roll (x-axis rotation)
|
22
27
|
sinr_cosp = 2.0 * (w * x + y * z)
|
23
28
|
cosr_cosp = 1.0 - 2.0 * (x * x + y * y)
|
24
|
-
roll =
|
29
|
+
roll = jax.lax.atan2(sinr_cosp, cosr_cosp)
|
25
30
|
|
26
31
|
# Pitch (y-axis rotation)
|
27
32
|
sinp = 2.0 * (w * y - z * x)
|
28
|
-
|
29
|
-
|
30
|
-
pitch = jnp.where(
|
31
|
-
jnp.abs(sinp) >= 1.0,
|
32
|
-
jnp.sign(sinp) * jnp.pi / 2.0, # Use 90 degrees if out of range
|
33
|
-
jnp.arcsin(sinp),
|
34
|
-
)
|
33
|
+
sinp = jnp.clip(sinp, -1.0, 1.0) # Clamp to valid domain
|
34
|
+
pitch = jax.lax.asin(sinp)
|
35
35
|
|
36
36
|
# Yaw (z-axis rotation)
|
37
37
|
siny_cosp = 2.0 * (w * z + x * y)
|
38
38
|
cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
|
39
|
-
yaw =
|
39
|
+
yaw = jax.lax.atan2(siny_cosp, cosy_cosp)
|
40
|
+
|
41
|
+
return jnp.stack([roll, pitch, yaw], axis=-1)
|
42
|
+
|
43
|
+
|
44
|
+
def quat_to_yaw(quat_4: Array, eps: float = 1e-6) -> Array:
|
45
|
+
"""Converts a quaternion to a yaw angle.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
quat_4: The quaternion to convert, shape (*, 4).
|
49
|
+
eps: A small epsilon value to avoid division by zero.
|
50
|
+
|
51
|
+
Returns:
|
52
|
+
The yaw angle, shape (*).
|
53
|
+
"""
|
54
|
+
# Normalize using a max + safe norm to handle extremely small values robustly
|
55
|
+
norm_sq = jnp.sum(quat_4**2, axis=-1, keepdims=True)
|
56
|
+
inv_norm = jax.lax.rsqrt(jnp.maximum(norm_sq, eps))
|
57
|
+
quat_4 = quat_4 * inv_norm
|
58
|
+
|
59
|
+
w, x, y, z = jnp.unstack(quat_4, axis=-1)
|
60
|
+
|
61
|
+
# Compute components with clamping to avoid rounding errors near limits
|
62
|
+
siny_cosp = 2.0 * (w * z + x * y)
|
63
|
+
cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
|
40
64
|
|
41
|
-
return
|
65
|
+
return jax.lax.atan2(siny_cosp, cosy_cosp)
|
42
66
|
|
43
67
|
|
44
68
|
def euler_to_quat(euler_3: Array) -> Array:
|
@@ -89,7 +113,12 @@ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Ar
|
|
89
113
|
return rotate_vector_by_quat(jnp.array([0, 0, -9.81]), quat, inverse=True, eps=eps)
|
90
114
|
|
91
115
|
|
92
|
-
def rotate_vector_by_quat(
|
116
|
+
def rotate_vector_by_quat(
|
117
|
+
vector: Array,
|
118
|
+
quat: Array,
|
119
|
+
inverse: bool = False,
|
120
|
+
eps: float = 1e-6,
|
121
|
+
) -> Array:
|
93
122
|
"""Rotates a vector by a quaternion.
|
94
123
|
|
95
124
|
Args:
|
xax/task/mixins/train.py
CHANGED
@@ -678,7 +678,7 @@ class TrainMixin(
|
|
678
678
|
|
679
679
|
def log_state(self) -> None:
|
680
680
|
logger.log(LOG_STATUS, self.task_path)
|
681
|
-
logger.log(LOG_STATUS, self.
|
681
|
+
logger.log(LOG_STATUS, self.exp_dir)
|
682
682
|
logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
|
683
683
|
self.logger.log_file("state.txt", get_state_file_string(self))
|
684
684
|
self.logger.log_file("training_code.py", get_training_code(self))
|
xax/utils/debugging.py
CHANGED
@@ -51,9 +51,25 @@ def get_named_leaves(
|
|
51
51
|
return ret
|
52
52
|
|
53
53
|
|
54
|
-
def
|
55
|
-
|
54
|
+
def breakpoint_if_nonfinite(x: Array) -> None:
|
55
|
+
is_finite = jnp.isfinite(x).all()
|
56
56
|
|
57
|
+
def true_fn(x: Array) -> None:
|
58
|
+
pass
|
57
59
|
|
58
|
-
def
|
59
|
-
|
60
|
+
def false_fn(x: Array) -> None:
|
61
|
+
jax.debug.breakpoint()
|
62
|
+
|
63
|
+
jax.lax.cond(is_finite, true_fn, false_fn, x)
|
64
|
+
|
65
|
+
|
66
|
+
def log_if_nonfinite(x: Array, loc: str) -> None:
|
67
|
+
is_finite = jnp.isfinite(x).all()
|
68
|
+
|
69
|
+
def true_fn(x: Array) -> None:
|
70
|
+
pass
|
71
|
+
|
72
|
+
def false_fn(x: Array) -> None:
|
73
|
+
jax.debug.print("=== NaNs: {loc} ===", loc=loc)
|
74
|
+
|
75
|
+
jax.lax.cond(is_finite, true_fn, false_fn, x)
|
xax/utils/pytree.py
CHANGED
@@ -1,12 +1,15 @@
|
|
1
1
|
"""Utils for accessing, modifying, and otherwise manipulating pytrees."""
|
2
2
|
|
3
|
+
from dataclasses import fields, is_dataclass
|
3
4
|
from typing import Mapping, Sequence, TypeVar
|
4
5
|
|
5
6
|
import chex
|
6
7
|
import equinox as eqx
|
7
8
|
import jax
|
8
9
|
import jax.numpy as jnp
|
10
|
+
import numpy as np
|
9
11
|
from jax import Array
|
12
|
+
from jax.core import get_aval
|
10
13
|
from jaxtyping import PRNGKeyArray, PyTree
|
11
14
|
|
12
15
|
T = TypeVar("T")
|
@@ -258,18 +261,77 @@ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
|
|
258
261
|
def get_pytree_mapping(pytree: PyTree) -> dict[str, Array]:
|
259
262
|
leaves: dict[str, Array] = {}
|
260
263
|
|
261
|
-
def _get_str(thing: PyTree) -> str:
|
262
|
-
if isinstance(thing, str):
|
263
|
-
return thing
|
264
|
-
if isinstance(thing, Sequence):
|
265
|
-
return "/".join(_get_str(x) for x in thing)
|
266
|
-
if isinstance(thing, Mapping):
|
267
|
-
return "/".join(f"{_get_str(k)}:{_get_str(v)}" for k, v in thing.items())
|
268
|
-
return str(thing)
|
269
|
-
|
270
264
|
def _get_leaf(path: tuple, x: PyTree) -> None:
|
271
265
|
if isinstance(x, jnp.ndarray):
|
272
|
-
leaves[
|
266
|
+
leaves[jax.tree_util.keystr(path, simple=True, separator="/")] = x
|
273
267
|
|
274
268
|
jax.tree.map_with_path(_get_leaf, pytree)
|
275
269
|
return leaves
|
270
|
+
|
271
|
+
|
272
|
+
def diff_pytree(tree_a: PyTree, tree_b: PyTree, prefix: str = "") -> list[str]:
|
273
|
+
diffs = []
|
274
|
+
|
275
|
+
# Handles dataclasses.
|
276
|
+
if is_dataclass(tree_a) and is_dataclass(tree_b):
|
277
|
+
if type(tree_a) is not type(tree_b):
|
278
|
+
diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
|
279
|
+
return diffs
|
280
|
+
for field in fields(tree_a):
|
281
|
+
attr_a, attr_b = getattr(tree_a, field.name), getattr(tree_b, field.name)
|
282
|
+
diffs.extend(diff_pytree(attr_a, attr_b, prefix + f"{field.name}."))
|
283
|
+
return diffs
|
284
|
+
|
285
|
+
# Handle dict-like objects
|
286
|
+
elif isinstance(tree_a, Mapping) and isinstance(tree_b, Mapping):
|
287
|
+
if type(tree_a) is not type(tree_b):
|
288
|
+
diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
|
289
|
+
return diffs
|
290
|
+
keys_a, keys_b = set(tree_a.keys()), set(tree_b.keys())
|
291
|
+
for k in keys_a - keys_b:
|
292
|
+
diffs.append(f"{prefix}{k}: present in A only")
|
293
|
+
for k in keys_b - keys_a:
|
294
|
+
diffs.append(f"{prefix}{k}: present in B only")
|
295
|
+
for k in keys_a & keys_b:
|
296
|
+
diffs.extend(diff_pytree(tree_a[k], tree_b[k], prefix + f"{k}."))
|
297
|
+
return diffs
|
298
|
+
|
299
|
+
# Handle tuple/list
|
300
|
+
elif isinstance(tree_a, Sequence) and isinstance(tree_b, Sequence):
|
301
|
+
if type(tree_a) is not type(tree_b):
|
302
|
+
diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
|
303
|
+
return diffs
|
304
|
+
if len(tree_a) != len(tree_b):
|
305
|
+
diffs.append(f"{prefix}: different lengths {len(tree_a)} vs {len(tree_b)}")
|
306
|
+
for i, (a_i, b_i) in enumerate(zip(tree_a, tree_b, strict=True)):
|
307
|
+
diffs.extend(diff_pytree(a_i, b_i, prefix + f"[{i}]."))
|
308
|
+
return diffs
|
309
|
+
|
310
|
+
# Handles basic types.
|
311
|
+
elif isinstance(tree_a, (int, float, bool, str, type(None), np.number, np.bool, bytes)):
|
312
|
+
if tree_a != tree_b:
|
313
|
+
diffs.append(f"{prefix}: {tree_a!r} vs {tree_b!r}")
|
314
|
+
return diffs
|
315
|
+
|
316
|
+
# Handles Numpy arrays.
|
317
|
+
elif isinstance(tree_a, np.ndarray) and isinstance(tree_b, np.ndarray):
|
318
|
+
if tree_a.shape != tree_b.shape:
|
319
|
+
diffs.append(f"{prefix}: shape {tree_a.shape} vs {tree_b.shape}")
|
320
|
+
if tree_a.dtype != tree_b.dtype:
|
321
|
+
diffs.append(f"{prefix}: dtype {tree_a.dtype} vs {tree_b.dtype}")
|
322
|
+
return diffs
|
323
|
+
|
324
|
+
# Handle arrays (check shape/dtype)
|
325
|
+
elif isinstance(tree_a, jnp.ndarray) and isinstance(tree_b, jnp.ndarray):
|
326
|
+
if tree_a.shape != tree_b.shape:
|
327
|
+
diffs.append(f"{prefix}: shape {tree_a.shape} vs {tree_b.shape}")
|
328
|
+
if tree_a.dtype != tree_b.dtype:
|
329
|
+
diffs.append(f"{prefix}: dtype {tree_a.dtype} vs {tree_b.dtype}")
|
330
|
+
aval_a = get_aval(tree_a)
|
331
|
+
aval_b = get_aval(tree_b)
|
332
|
+
if aval_a != aval_b: # pyright: ignore[reportAttributeAccessIssue]
|
333
|
+
diffs.append(f"{prefix}: aval {aval_a} vs {aval_b}")
|
334
|
+
return diffs
|
335
|
+
|
336
|
+
else:
|
337
|
+
raise ValueError(f"Unknown type: {type(tree_a)}")
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=gTdL72cZZzdpYkHj1Ks981o3nE_BNlvIv1ISYlQarmM,16944
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
@@ -9,10 +9,10 @@ xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
|
|
9
9
|
xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
|
10
10
|
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
11
|
xax/nn/attention.py,sha256=m6yEoRqf7-wLgrEltaR6CxF_Cody0MaNtAkuKk39qJI,31176
|
12
|
-
xax/nn/distributions.py,sha256=
|
12
|
+
xax/nn/distributions.py,sha256=6YOjyiPOC7XLDaMYpFNBlLCu3eLgDAeqIg9FoKfYLL4,6497
|
13
13
|
xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
|
14
14
|
xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
|
15
|
-
xax/nn/geom.py,sha256=
|
15
|
+
xax/nn/geom.py,sha256=ataKbQFXTebK9fM10CFyxsHOPGXhn26P4jakoc9Wqek,11424
|
16
16
|
xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
|
17
17
|
xax/nn/metrics.py,sha256=zuvPXlRQczBTLHD4ilNGmZaiq6Yie3rxCMq6JkI_kos,3154
|
18
18
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
@@ -43,16 +43,16 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
|
|
43
43
|
xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
|
44
44
|
xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
|
45
45
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
46
|
-
xax/task/mixins/train.py,sha256=
|
46
|
+
xax/task/mixins/train.py,sha256=qb0zpsyeCk_U8Sk8THxtXkUVwj5r0lOlMLNRTctvcWU,32812
|
47
47
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
48
|
-
xax/utils/debugging.py,sha256=
|
48
|
+
xax/utils/debugging.py,sha256=85JYIdnzLnvXsuli-4YHei_3tE3DnX3rmDSARKW2u1M,2192
|
49
49
|
xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
|
50
50
|
xax/utils/jax.py,sha256=6cP95-rcjkRt1fefkZWJQhJhH0uUYWJB3w4NP1-aDp0,10136
|
51
51
|
xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
|
52
52
|
xax/utils/logging.py,sha256=Kkyma_LJXqrN2HTQ214gRP_9ih3_bKk115MWC60lQWM,6656
|
53
53
|
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
54
54
|
xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
|
55
|
-
xax/utils/pytree.py,sha256=
|
55
|
+
xax/utils/pytree.py,sha256=qC7OfCydX3N5yDIgcWwiXFIdpQZg3uxgBP2H85eNmzQ,12649
|
56
56
|
xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
|
57
57
|
xax/utils/text.py,sha256=xS02aSzdywl3KIaNSpKWcxdd37oYlUJtu9wIjkc1wVc,10654
|
58
58
|
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -60,9 +60,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
60
60
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
61
61
|
xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
|
62
62
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
63
|
-
xax-0.3.
|
64
|
-
xax-0.3.
|
65
|
-
xax-0.3.
|
66
|
-
xax-0.3.
|
67
|
-
xax-0.3.
|
68
|
-
xax-0.3.
|
63
|
+
xax-0.3.13.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
64
|
+
xax-0.3.13.dist-info/METADATA,sha256=Gl4h20HE74S6yx7NlKB64JF1ngQMx7e8gM5uu1SEH-M,1247
|
65
|
+
xax-0.3.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
66
|
+
xax-0.3.13.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
|
67
|
+
xax-0.3.13.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
68
|
+
xax-0.3.13.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|