xax 0.3.11__tar.gz → 0.3.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. {xax-0.3.11/xax.egg-info → xax-0.3.12}/PKG-INFO +1 -1
  2. {xax-0.3.11 → xax-0.3.12}/xax/__init__.py +10 -2
  3. {xax-0.3.11 → xax-0.3.12}/xax/nn/distributions.py +1 -2
  4. {xax-0.3.11 → xax-0.3.12}/xax/utils/pytree.py +74 -10
  5. {xax-0.3.11 → xax-0.3.12/xax.egg-info}/PKG-INFO +1 -1
  6. {xax-0.3.11 → xax-0.3.12}/LICENSE +0 -0
  7. {xax-0.3.11 → xax-0.3.12}/MANIFEST.in +0 -0
  8. {xax-0.3.11 → xax-0.3.12}/README.md +0 -0
  9. {xax-0.3.11 → xax-0.3.12}/pyproject.toml +0 -0
  10. {xax-0.3.11 → xax-0.3.12}/setup.cfg +0 -0
  11. {xax-0.3.11 → xax-0.3.12}/setup.py +0 -0
  12. {xax-0.3.11 → xax-0.3.12}/xax/cli/__init__.py +0 -0
  13. {xax-0.3.11 → xax-0.3.12}/xax/cli/edit_config.py +0 -0
  14. {xax-0.3.11 → xax-0.3.12}/xax/core/__init__.py +0 -0
  15. {xax-0.3.11 → xax-0.3.12}/xax/core/conf.py +0 -0
  16. {xax-0.3.11 → xax-0.3.12}/xax/core/state.py +0 -0
  17. {xax-0.3.11 → xax-0.3.12}/xax/nn/__init__.py +0 -0
  18. {xax-0.3.11 → xax-0.3.12}/xax/nn/attention.py +0 -0
  19. {xax-0.3.11 → xax-0.3.12}/xax/nn/embeddings.py +0 -0
  20. {xax-0.3.11 → xax-0.3.12}/xax/nn/functions.py +0 -0
  21. {xax-0.3.11 → xax-0.3.12}/xax/nn/geom.py +0 -0
  22. {xax-0.3.11 → xax-0.3.12}/xax/nn/losses.py +0 -0
  23. {xax-0.3.11 → xax-0.3.12}/xax/nn/metrics.py +0 -0
  24. {xax-0.3.11 → xax-0.3.12}/xax/nn/parallel.py +0 -0
  25. {xax-0.3.11 → xax-0.3.12}/xax/nn/ssm.py +0 -0
  26. {xax-0.3.11 → xax-0.3.12}/xax/py.typed +0 -0
  27. {xax-0.3.11 → xax-0.3.12}/xax/requirements-dev.txt +0 -0
  28. {xax-0.3.11 → xax-0.3.12}/xax/requirements.txt +0 -0
  29. {xax-0.3.11 → xax-0.3.12}/xax/task/__init__.py +0 -0
  30. {xax-0.3.11 → xax-0.3.12}/xax/task/base.py +0 -0
  31. {xax-0.3.11 → xax-0.3.12}/xax/task/launchers/__init__.py +0 -0
  32. {xax-0.3.11 → xax-0.3.12}/xax/task/launchers/base.py +0 -0
  33. {xax-0.3.11 → xax-0.3.12}/xax/task/launchers/cli.py +0 -0
  34. {xax-0.3.11 → xax-0.3.12}/xax/task/launchers/single_process.py +0 -0
  35. {xax-0.3.11 → xax-0.3.12}/xax/task/logger.py +0 -0
  36. {xax-0.3.11 → xax-0.3.12}/xax/task/loggers/__init__.py +0 -0
  37. {xax-0.3.11 → xax-0.3.12}/xax/task/loggers/callback.py +0 -0
  38. {xax-0.3.11 → xax-0.3.12}/xax/task/loggers/json.py +0 -0
  39. {xax-0.3.11 → xax-0.3.12}/xax/task/loggers/state.py +0 -0
  40. {xax-0.3.11 → xax-0.3.12}/xax/task/loggers/stdout.py +0 -0
  41. {xax-0.3.11 → xax-0.3.12}/xax/task/loggers/tensorboard.py +0 -0
  42. {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/__init__.py +0 -0
  43. {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/artifacts.py +0 -0
  44. {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/checkpointing.py +0 -0
  45. {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/compile.py +0 -0
  46. {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/cpu_stats.py +0 -0
  47. {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/data_loader.py +0 -0
  48. {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/gpu_stats.py +0 -0
  49. {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/logger.py +0 -0
  50. {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/process.py +0 -0
  51. {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/runnable.py +0 -0
  52. {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/step_wrapper.py +0 -0
  53. {xax-0.3.11 → xax-0.3.12}/xax/task/mixins/train.py +0 -0
  54. {xax-0.3.11 → xax-0.3.12}/xax/task/script.py +0 -0
  55. {xax-0.3.11 → xax-0.3.12}/xax/task/task.py +0 -0
  56. {xax-0.3.11 → xax-0.3.12}/xax/utils/__init__.py +0 -0
  57. {xax-0.3.11 → xax-0.3.12}/xax/utils/data/__init__.py +0 -0
  58. {xax-0.3.11 → xax-0.3.12}/xax/utils/data/collate.py +0 -0
  59. {xax-0.3.11 → xax-0.3.12}/xax/utils/debugging.py +0 -0
  60. {xax-0.3.11 → xax-0.3.12}/xax/utils/experiments.py +0 -0
  61. {xax-0.3.11 → xax-0.3.12}/xax/utils/jax.py +0 -0
  62. {xax-0.3.11 → xax-0.3.12}/xax/utils/jaxpr.py +0 -0
  63. {xax-0.3.11 → xax-0.3.12}/xax/utils/logging.py +0 -0
  64. {xax-0.3.11 → xax-0.3.12}/xax/utils/numpy.py +0 -0
  65. {xax-0.3.11 → xax-0.3.12}/xax/utils/profile.py +0 -0
  66. {xax-0.3.11 → xax-0.3.12}/xax/utils/tensorboard.py +0 -0
  67. {xax-0.3.11 → xax-0.3.12}/xax/utils/text.py +0 -0
  68. {xax-0.3.11 → xax-0.3.12}/xax/utils/types/__init__.py +0 -0
  69. {xax-0.3.11 → xax-0.3.12}/xax/utils/types/frozen_dict.py +0 -0
  70. {xax-0.3.11 → xax-0.3.12}/xax/utils/types/hashable_array.py +0 -0
  71. {xax-0.3.11 → xax-0.3.12}/xax.egg-info/SOURCES.txt +0 -0
  72. {xax-0.3.11 → xax-0.3.12}/xax.egg-info/dependency_links.txt +0 -0
  73. {xax-0.3.11 → xax-0.3.12}/xax.egg-info/entry_points.txt +0 -0
  74. {xax-0.3.11 → xax-0.3.12}/xax.egg-info/requires.txt +0 -0
  75. {xax-0.3.11 → xax-0.3.12}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.11
3
+ Version: 0.3.12
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.11"
15
+ __version__ = "0.3.12"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -138,6 +138,7 @@ __all__ = [
138
138
  "worker_chunk",
139
139
  "profile",
140
140
  "compute_nan_ratio",
141
+ "diff_pytree",
141
142
  "flatten_array",
142
143
  "flatten_pytree",
143
144
  "get_pytree_mapping",
@@ -330,6 +331,7 @@ NAME_MAP: dict[str, str] = {
330
331
  "worker_chunk": "utils.numpy",
331
332
  "profile": "utils.profile",
332
333
  "compute_nan_ratio": "utils.pytree",
334
+ "diff_pytree": "utils.pytree",
333
335
  "flatten_array": "utils.pytree",
334
336
  "flatten_pytree": "utils.pytree",
335
337
  "get_pytree_mapping": "utils.pytree",
@@ -413,7 +415,12 @@ if IMPORT_ALL or TYPE_CHECKING:
413
415
  TransformerCache,
414
416
  TransformerStack,
415
417
  )
416
- from xax.nn.distributions import Categorical, Distribution, MixtureOfGaussians, Normal
418
+ from xax.nn.distributions import (
419
+ Categorical,
420
+ Distribution,
421
+ MixtureOfGaussians,
422
+ Normal,
423
+ )
417
424
  from xax.nn.embeddings import (
418
425
  EmbeddingKind,
419
426
  FourierEmbeddings,
@@ -518,6 +525,7 @@ if IMPORT_ALL or TYPE_CHECKING:
518
525
  from xax.utils.profile import profile
519
526
  from xax.utils.pytree import (
520
527
  compute_nan_ratio,
528
+ diff_pytree,
521
529
  flatten_array,
522
530
  flatten_pytree,
523
531
  get_pytree_mapping,
@@ -12,7 +12,6 @@ __all__ = [
12
12
  "MixtureOfGaussians",
13
13
  ]
14
14
 
15
- import math
16
15
  from abc import ABC, abstractmethod
17
16
 
18
17
  import jax
@@ -20,7 +19,7 @@ import jax.numpy as jnp
20
19
  from jaxtyping import Array, PRNGKeyArray
21
20
 
22
21
  STD_CLIP = 1e-6
23
- LOGIT_CLIP = math.log(1e4)
22
+ LOGIT_CLIP = 6.0
24
23
 
25
24
 
26
25
  class Distribution(ABC):
@@ -1,12 +1,15 @@
1
1
  """Utils for accessing, modifying, and otherwise manipulating pytrees."""
2
2
 
3
+ from dataclasses import fields, is_dataclass
3
4
  from typing import Mapping, Sequence, TypeVar
4
5
 
5
6
  import chex
6
7
  import equinox as eqx
7
8
  import jax
8
9
  import jax.numpy as jnp
10
+ import numpy as np
9
11
  from jax import Array
12
+ from jax.core import get_aval
10
13
  from jaxtyping import PRNGKeyArray, PyTree
11
14
 
12
15
  T = TypeVar("T")
@@ -258,18 +261,79 @@ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
258
261
  def get_pytree_mapping(pytree: PyTree) -> dict[str, Array]:
259
262
  leaves: dict[str, Array] = {}
260
263
 
261
- def _get_str(thing: PyTree) -> str:
262
- if isinstance(thing, str):
263
- return thing
264
- if isinstance(thing, Sequence):
265
- return "/".join(_get_str(x) for x in thing)
266
- if isinstance(thing, Mapping):
267
- return "/".join(f"{_get_str(k)}:{_get_str(v)}" for k, v in thing.items())
268
- return str(thing)
269
-
270
264
  def _get_leaf(path: tuple, x: PyTree) -> None:
271
265
  if isinstance(x, jnp.ndarray):
272
- leaves[_get_str(path)] = x
266
+ leaves[jax.tree_util.keystr(path, simple=True, separator="/")] = x
273
267
 
274
268
  jax.tree.map_with_path(_get_leaf, pytree)
275
269
  return leaves
270
+
271
+
272
+ def diff_pytree(tree_a: PyTree, tree_b: PyTree, prefix: str = "") -> list[str]:
273
+ diffs = []
274
+
275
+ # Handles dataclasses.
276
+ if is_dataclass(tree_a) and is_dataclass(tree_b):
277
+ for field in fields(tree_a):
278
+ attr_a, attr_b = getattr(tree_a, field.name), getattr(tree_b, field.name)
279
+ diffs.extend(diff_pytree(attr_a, attr_b, prefix + f"{field.name}."))
280
+ return diffs
281
+
282
+ # Handle dict-like objects
283
+ elif isinstance(tree_a, Mapping) and isinstance(tree_b, Mapping):
284
+ if type(tree_a) is not type(tree_b):
285
+ diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
286
+ return diffs
287
+ keys_a, keys_b = set(tree_a.keys()), set(tree_b.keys())
288
+ for k in keys_a - keys_b:
289
+ diffs.append(f"{prefix}{k}: present in A only")
290
+ for k in keys_b - keys_a:
291
+ diffs.append(f"{prefix}{k}: present in B only")
292
+ for k in keys_a & keys_b:
293
+ diffs.extend(diff_pytree(tree_a[k], tree_b[k], prefix + f"{k}."))
294
+ return diffs
295
+
296
+ # Handle tuple/list
297
+ elif isinstance(tree_a, Sequence) and isinstance(tree_b, Sequence):
298
+ if type(tree_a) is not type(tree_b):
299
+ diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
300
+ return diffs
301
+ if len(tree_a) != len(tree_b):
302
+ diffs.append(f"{prefix}: different lengths {len(tree_a)} vs {len(tree_b)}")
303
+ for i, (a_i, b_i) in enumerate(zip(tree_a, tree_b, strict=True)):
304
+ diffs.extend(diff_pytree(a_i, b_i, prefix + f"[{i}]."))
305
+ return diffs
306
+
307
+ # Handles basic types.
308
+ elif isinstance(tree_a, (int, float, bool, str, type(None), np.number, np.bool, bytes)):
309
+ if tree_a != tree_b:
310
+ diffs.append(f"{prefix}: {tree_a!r} vs {tree_b!r}")
311
+ return diffs
312
+
313
+ # Handles Numpy arrays.
314
+ elif isinstance(tree_a, np.ndarray) and isinstance(tree_b, np.ndarray):
315
+ if tree_a.shape != tree_b.shape:
316
+ diffs.append(f"{prefix}: shape {tree_a.shape} vs {tree_b.shape}")
317
+ if tree_a.dtype != tree_b.dtype:
318
+ diffs.append(f"{prefix}: dtype {tree_a.dtype} vs {tree_b.dtype}")
319
+ return diffs
320
+
321
+ # Handle arrays (check shape/dtype)
322
+ elif isinstance(tree_a, jnp.ndarray) and isinstance(tree_b, jnp.ndarray):
323
+ if tree_a.shape != tree_b.shape:
324
+ diffs.append(f"{prefix}: shape {tree_a.shape} vs {tree_b.shape}")
325
+ if tree_a.dtype != tree_b.dtype:
326
+ diffs.append(f"{prefix}: dtype {tree_a.dtype} vs {tree_b.dtype}")
327
+ aval_a = get_aval(tree_a)
328
+ aval_b = get_aval(tree_b)
329
+ if aval_a != aval_b: # pyright: ignore[reportAttributeAccessIssue]
330
+ diffs.append(f"{prefix}: aval {aval_a} vs {aval_b}")
331
+ return diffs
332
+
333
+ # Handle mismatched types
334
+ elif type(tree_a) is not type(tree_b):
335
+ diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
336
+ return diffs
337
+
338
+ else:
339
+ raise ValueError(f"Unknown type: {type(tree_a)}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.11
3
+ Version: 0.3.12
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes