xax 0.3.11__tar.gz → 0.3.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.3.11/xax.egg-info → xax-0.3.12}/PKG-INFO +1 -1
- {xax-0.3.11 → xax-0.3.12}/xax/__init__.py +10 -2
- {xax-0.3.11 → xax-0.3.12}/xax/nn/distributions.py +1 -2
- {xax-0.3.11 → xax-0.3.12}/xax/utils/pytree.py +74 -10
- {xax-0.3.11 → xax-0.3.12/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.3.11 → xax-0.3.12}/LICENSE +0 -0
- {xax-0.3.11 → xax-0.3.12}/MANIFEST.in +0 -0
- {xax-0.3.11 → xax-0.3.12}/README.md +0 -0
- {xax-0.3.11 → xax-0.3.12}/pyproject.toml +0 -0
- {xax-0.3.11 → xax-0.3.12}/setup.cfg +0 -0
- {xax-0.3.11 → xax-0.3.12}/setup.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/cli/__init__.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/cli/edit_config.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/core/__init__.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/core/conf.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/core/state.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/nn/__init__.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/nn/attention.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/nn/embeddings.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/nn/functions.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/nn/geom.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/nn/losses.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/nn/metrics.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/nn/parallel.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/nn/ssm.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/py.typed +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/requirements-dev.txt +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/requirements.txt +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/__init__.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/base.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/launchers/__init__.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/launchers/base.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/launchers/cli.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/launchers/single_process.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/logger.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/loggers/__init__.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/loggers/callback.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/loggers/json.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/loggers/state.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/loggers/stdout.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/__init__.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/compile.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/logger.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/process.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/runnable.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/train.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/script.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/task/task.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/__init__.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/data/__init__.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/data/collate.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/debugging.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/experiments.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/jax.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/jaxpr.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/logging.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/numpy.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/profile.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/tensorboard.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/text.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/types/__init__.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax.egg-info/requires.txt +0 -0
- {xax-0.3.11 → xax-0.3.12}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.3.11"
|
15
|
+
__version__ = "0.3.12"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -138,6 +138,7 @@ __all__ = [
|
|
138
138
|
"worker_chunk",
|
139
139
|
"profile",
|
140
140
|
"compute_nan_ratio",
|
141
|
+
"diff_pytree",
|
141
142
|
"flatten_array",
|
142
143
|
"flatten_pytree",
|
143
144
|
"get_pytree_mapping",
|
@@ -330,6 +331,7 @@ NAME_MAP: dict[str, str] = {
|
|
330
331
|
"worker_chunk": "utils.numpy",
|
331
332
|
"profile": "utils.profile",
|
332
333
|
"compute_nan_ratio": "utils.pytree",
|
334
|
+
"diff_pytree": "utils.pytree",
|
333
335
|
"flatten_array": "utils.pytree",
|
334
336
|
"flatten_pytree": "utils.pytree",
|
335
337
|
"get_pytree_mapping": "utils.pytree",
|
@@ -413,7 +415,12 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
413
415
|
TransformerCache,
|
414
416
|
TransformerStack,
|
415
417
|
)
|
416
|
-
from xax.nn.distributions import
|
418
|
+
from xax.nn.distributions import (
|
419
|
+
Categorical,
|
420
|
+
Distribution,
|
421
|
+
MixtureOfGaussians,
|
422
|
+
Normal,
|
423
|
+
)
|
417
424
|
from xax.nn.embeddings import (
|
418
425
|
EmbeddingKind,
|
419
426
|
FourierEmbeddings,
|
@@ -518,6 +525,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
518
525
|
from xax.utils.profile import profile
|
519
526
|
from xax.utils.pytree import (
|
520
527
|
compute_nan_ratio,
|
528
|
+
diff_pytree,
|
521
529
|
flatten_array,
|
522
530
|
flatten_pytree,
|
523
531
|
get_pytree_mapping,
|
@@ -12,7 +12,6 @@ __all__ = [
|
|
12
12
|
"MixtureOfGaussians",
|
13
13
|
]
|
14
14
|
|
15
|
-
import math
|
16
15
|
from abc import ABC, abstractmethod
|
17
16
|
|
18
17
|
import jax
|
@@ -20,7 +19,7 @@ import jax.numpy as jnp
|
|
20
19
|
from jaxtyping import Array, PRNGKeyArray
|
21
20
|
|
22
21
|
STD_CLIP = 1e-6
|
23
|
-
LOGIT_CLIP =
|
22
|
+
LOGIT_CLIP = 6.0
|
24
23
|
|
25
24
|
|
26
25
|
class Distribution(ABC):
|
@@ -1,12 +1,15 @@
|
|
1
1
|
"""Utils for accessing, modifying, and otherwise manipulating pytrees."""
|
2
2
|
|
3
|
+
from dataclasses import fields, is_dataclass
|
3
4
|
from typing import Mapping, Sequence, TypeVar
|
4
5
|
|
5
6
|
import chex
|
6
7
|
import equinox as eqx
|
7
8
|
import jax
|
8
9
|
import jax.numpy as jnp
|
10
|
+
import numpy as np
|
9
11
|
from jax import Array
|
12
|
+
from jax.core import get_aval
|
10
13
|
from jaxtyping import PRNGKeyArray, PyTree
|
11
14
|
|
12
15
|
T = TypeVar("T")
|
@@ -258,18 +261,79 @@ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
|
|
258
261
|
def get_pytree_mapping(pytree: PyTree) -> dict[str, Array]:
|
259
262
|
leaves: dict[str, Array] = {}
|
260
263
|
|
261
|
-
def _get_str(thing: PyTree) -> str:
|
262
|
-
if isinstance(thing, str):
|
263
|
-
return thing
|
264
|
-
if isinstance(thing, Sequence):
|
265
|
-
return "/".join(_get_str(x) for x in thing)
|
266
|
-
if isinstance(thing, Mapping):
|
267
|
-
return "/".join(f"{_get_str(k)}:{_get_str(v)}" for k, v in thing.items())
|
268
|
-
return str(thing)
|
269
|
-
|
270
264
|
def _get_leaf(path: tuple, x: PyTree) -> None:
|
271
265
|
if isinstance(x, jnp.ndarray):
|
272
|
-
leaves[_get_str(path)] = x
|
266
|
+
leaves[jax.tree_util.keystr(path, simple=True, separator="/")] = x
|
273
267
|
|
274
268
|
jax.tree.map_with_path(_get_leaf, pytree)
|
275
269
|
return leaves
|
270
|
+
|
271
|
+
|
272
|
+
def diff_pytree(tree_a: PyTree, tree_b: PyTree, prefix: str = "") -> list[str]:
|
273
|
+
diffs = []
|
274
|
+
|
275
|
+
# Handles dataclasses.
|
276
|
+
if is_dataclass(tree_a) and is_dataclass(tree_b):
|
277
|
+
for field in fields(tree_a):
|
278
|
+
attr_a, attr_b = getattr(tree_a, field.name), getattr(tree_b, field.name)
|
279
|
+
diffs.extend(diff_pytree(attr_a, attr_b, prefix + f"{field.name}."))
|
280
|
+
return diffs
|
281
|
+
|
282
|
+
# Handle dict-like objects
|
283
|
+
elif isinstance(tree_a, Mapping) and isinstance(tree_b, Mapping):
|
284
|
+
if type(tree_a) is not type(tree_b):
|
285
|
+
diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
|
286
|
+
return diffs
|
287
|
+
keys_a, keys_b = set(tree_a.keys()), set(tree_b.keys())
|
288
|
+
for k in keys_a - keys_b:
|
289
|
+
diffs.append(f"{prefix}{k}: present in A only")
|
290
|
+
for k in keys_b - keys_a:
|
291
|
+
diffs.append(f"{prefix}{k}: present in B only")
|
292
|
+
for k in keys_a & keys_b:
|
293
|
+
diffs.extend(diff_pytree(tree_a[k], tree_b[k], prefix + f"{k}."))
|
294
|
+
return diffs
|
295
|
+
|
296
|
+
# Handle tuple/list
|
297
|
+
elif isinstance(tree_a, Sequence) and isinstance(tree_b, Sequence):
|
298
|
+
if type(tree_a) is not type(tree_b):
|
299
|
+
diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
|
300
|
+
return diffs
|
301
|
+
if len(tree_a) != len(tree_b):
|
302
|
+
diffs.append(f"{prefix}: different lengths {len(tree_a)} vs {len(tree_b)}")
|
303
|
+
for i, (a_i, b_i) in enumerate(zip(tree_a, tree_b, strict=True)):
|
304
|
+
diffs.extend(diff_pytree(a_i, b_i, prefix + f"[{i}]."))
|
305
|
+
return diffs
|
306
|
+
|
307
|
+
# Handles basic types.
|
308
|
+
elif isinstance(tree_a, (int, float, bool, str, type(None), np.number, np.bool, bytes)):
|
309
|
+
if tree_a != tree_b:
|
310
|
+
diffs.append(f"{prefix}: {tree_a!r} vs {tree_b!r}")
|
311
|
+
return diffs
|
312
|
+
|
313
|
+
# Handles Numpy arrays.
|
314
|
+
elif isinstance(tree_a, np.ndarray) and isinstance(tree_b, np.ndarray):
|
315
|
+
if tree_a.shape != tree_b.shape:
|
316
|
+
diffs.append(f"{prefix}: shape {tree_a.shape} vs {tree_b.shape}")
|
317
|
+
if tree_a.dtype != tree_b.dtype:
|
318
|
+
diffs.append(f"{prefix}: dtype {tree_a.dtype} vs {tree_b.dtype}")
|
319
|
+
return diffs
|
320
|
+
|
321
|
+
# Handle arrays (check shape/dtype)
|
322
|
+
elif isinstance(tree_a, jnp.ndarray) and isinstance(tree_b, jnp.ndarray):
|
323
|
+
if tree_a.shape != tree_b.shape:
|
324
|
+
diffs.append(f"{prefix}: shape {tree_a.shape} vs {tree_b.shape}")
|
325
|
+
if tree_a.dtype != tree_b.dtype:
|
326
|
+
diffs.append(f"{prefix}: dtype {tree_a.dtype} vs {tree_b.dtype}")
|
327
|
+
aval_a = get_aval(tree_a)
|
328
|
+
aval_b = get_aval(tree_b)
|
329
|
+
if aval_a != aval_b: # pyright: ignore[reportAttributeAccessIssue]
|
330
|
+
diffs.append(f"{prefix}: aval {aval_a} vs {aval_b}")
|
331
|
+
return diffs
|
332
|
+
|
333
|
+
# Handle mismatched types
|
334
|
+
elif type(tree_a) is not type(tree_b):
|
335
|
+
diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
|
336
|
+
return diffs
|
337
|
+
|
338
|
+
else:
|
339
|
+
raise ValueError(f"Unknown type: {type(tree_a)}")
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|