xax 0.3.9__tar.gz → 0.3.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. {xax-0.3.9/xax.egg-info → xax-0.3.10}/PKG-INFO +1 -1
  2. {xax-0.3.9 → xax-0.3.10}/xax/__init__.py +1 -1
  3. {xax-0.3.9 → xax-0.3.10}/xax/utils/pytree.py +11 -4
  4. {xax-0.3.9 → xax-0.3.10/xax.egg-info}/PKG-INFO +1 -1
  5. {xax-0.3.9 → xax-0.3.10}/LICENSE +0 -0
  6. {xax-0.3.9 → xax-0.3.10}/MANIFEST.in +0 -0
  7. {xax-0.3.9 → xax-0.3.10}/README.md +0 -0
  8. {xax-0.3.9 → xax-0.3.10}/pyproject.toml +0 -0
  9. {xax-0.3.9 → xax-0.3.10}/setup.cfg +0 -0
  10. {xax-0.3.9 → xax-0.3.10}/setup.py +0 -0
  11. {xax-0.3.9 → xax-0.3.10}/xax/cli/__init__.py +0 -0
  12. {xax-0.3.9 → xax-0.3.10}/xax/cli/edit_config.py +0 -0
  13. {xax-0.3.9 → xax-0.3.10}/xax/core/__init__.py +0 -0
  14. {xax-0.3.9 → xax-0.3.10}/xax/core/conf.py +0 -0
  15. {xax-0.3.9 → xax-0.3.10}/xax/core/state.py +0 -0
  16. {xax-0.3.9 → xax-0.3.10}/xax/nn/__init__.py +0 -0
  17. {xax-0.3.9 → xax-0.3.10}/xax/nn/attention.py +0 -0
  18. {xax-0.3.9 → xax-0.3.10}/xax/nn/distributions.py +0 -0
  19. {xax-0.3.9 → xax-0.3.10}/xax/nn/embeddings.py +0 -0
  20. {xax-0.3.9 → xax-0.3.10}/xax/nn/functions.py +0 -0
  21. {xax-0.3.9 → xax-0.3.10}/xax/nn/geom.py +0 -0
  22. {xax-0.3.9 → xax-0.3.10}/xax/nn/losses.py +0 -0
  23. {xax-0.3.9 → xax-0.3.10}/xax/nn/metrics.py +0 -0
  24. {xax-0.3.9 → xax-0.3.10}/xax/nn/parallel.py +0 -0
  25. {xax-0.3.9 → xax-0.3.10}/xax/nn/ssm.py +0 -0
  26. {xax-0.3.9 → xax-0.3.10}/xax/py.typed +0 -0
  27. {xax-0.3.9 → xax-0.3.10}/xax/requirements-dev.txt +0 -0
  28. {xax-0.3.9 → xax-0.3.10}/xax/requirements.txt +0 -0
  29. {xax-0.3.9 → xax-0.3.10}/xax/task/__init__.py +0 -0
  30. {xax-0.3.9 → xax-0.3.10}/xax/task/base.py +0 -0
  31. {xax-0.3.9 → xax-0.3.10}/xax/task/launchers/__init__.py +0 -0
  32. {xax-0.3.9 → xax-0.3.10}/xax/task/launchers/base.py +0 -0
  33. {xax-0.3.9 → xax-0.3.10}/xax/task/launchers/cli.py +0 -0
  34. {xax-0.3.9 → xax-0.3.10}/xax/task/launchers/single_process.py +0 -0
  35. {xax-0.3.9 → xax-0.3.10}/xax/task/logger.py +0 -0
  36. {xax-0.3.9 → xax-0.3.10}/xax/task/loggers/__init__.py +0 -0
  37. {xax-0.3.9 → xax-0.3.10}/xax/task/loggers/callback.py +0 -0
  38. {xax-0.3.9 → xax-0.3.10}/xax/task/loggers/json.py +0 -0
  39. {xax-0.3.9 → xax-0.3.10}/xax/task/loggers/state.py +0 -0
  40. {xax-0.3.9 → xax-0.3.10}/xax/task/loggers/stdout.py +0 -0
  41. {xax-0.3.9 → xax-0.3.10}/xax/task/loggers/tensorboard.py +0 -0
  42. {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/__init__.py +0 -0
  43. {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/artifacts.py +0 -0
  44. {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/checkpointing.py +0 -0
  45. {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/compile.py +0 -0
  46. {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/cpu_stats.py +0 -0
  47. {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/data_loader.py +0 -0
  48. {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/gpu_stats.py +0 -0
  49. {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/logger.py +0 -0
  50. {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/process.py +0 -0
  51. {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/runnable.py +0 -0
  52. {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/step_wrapper.py +0 -0
  53. {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/train.py +0 -0
  54. {xax-0.3.9 → xax-0.3.10}/xax/task/script.py +0 -0
  55. {xax-0.3.9 → xax-0.3.10}/xax/task/task.py +0 -0
  56. {xax-0.3.9 → xax-0.3.10}/xax/utils/__init__.py +0 -0
  57. {xax-0.3.9 → xax-0.3.10}/xax/utils/data/__init__.py +0 -0
  58. {xax-0.3.9 → xax-0.3.10}/xax/utils/data/collate.py +0 -0
  59. {xax-0.3.9 → xax-0.3.10}/xax/utils/debugging.py +0 -0
  60. {xax-0.3.9 → xax-0.3.10}/xax/utils/experiments.py +0 -0
  61. {xax-0.3.9 → xax-0.3.10}/xax/utils/jax.py +0 -0
  62. {xax-0.3.9 → xax-0.3.10}/xax/utils/jaxpr.py +0 -0
  63. {xax-0.3.9 → xax-0.3.10}/xax/utils/logging.py +0 -0
  64. {xax-0.3.9 → xax-0.3.10}/xax/utils/numpy.py +0 -0
  65. {xax-0.3.9 → xax-0.3.10}/xax/utils/profile.py +0 -0
  66. {xax-0.3.9 → xax-0.3.10}/xax/utils/tensorboard.py +0 -0
  67. {xax-0.3.9 → xax-0.3.10}/xax/utils/text.py +0 -0
  68. {xax-0.3.9 → xax-0.3.10}/xax/utils/types/__init__.py +0 -0
  69. {xax-0.3.9 → xax-0.3.10}/xax/utils/types/frozen_dict.py +0 -0
  70. {xax-0.3.9 → xax-0.3.10}/xax/utils/types/hashable_array.py +0 -0
  71. {xax-0.3.9 → xax-0.3.10}/xax.egg-info/SOURCES.txt +0 -0
  72. {xax-0.3.9 → xax-0.3.10}/xax.egg-info/dependency_links.txt +0 -0
  73. {xax-0.3.9 → xax-0.3.10}/xax.egg-info/entry_points.txt +0 -0
  74. {xax-0.3.9 → xax-0.3.10}/xax.egg-info/requires.txt +0 -0
  75. {xax-0.3.9 → xax-0.3.10}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.9
3
+ Version: 0.3.10
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.9"
15
+ __version__ = "0.3.10"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -1,6 +1,6 @@
1
1
  """Utils for accessing, modifying, and otherwise manipulating pytrees."""
2
2
 
3
- from typing import TypeVar
3
+ from typing import Mapping, Sequence, TypeVar
4
4
 
5
5
  import chex
6
6
  import equinox as eqx
@@ -258,11 +258,18 @@ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
def get_pytree_mapping(pytree: PyTree) -> dict[str, Array]:
    """Map the key path of every array leaf in ``pytree`` to that leaf.

    Leaves that are not ``jnp.ndarray`` instances are skipped. Each dict key is
    the leaf's key path rendered by ``_get_str``, with path components joined
    by ``"/"``.

    Args:
        pytree: The pytree to walk.

    Returns:
        A dict from path string to array leaf, in JAX's flattening order. If two
        paths render to the same string, the later leaf overwrites the earlier.
    """

    def _get_str(thing: PyTree) -> str:
        """Render a key path, or one component of it, as a string."""
        # ``str`` is itself a ``Sequence``; check it first so a string is
        # returned whole instead of being recursed into forever.
        if isinstance(thing, str):
            return thing
        if isinstance(thing, Sequence):
            return "/".join(_get_str(x) for x in thing)
        if isinstance(thing, Mapping):
            return "/".join(f"{_get_str(k)}:{_get_str(v)}" for k, v in thing.items())
        # NOTE(review): JAX key entries (DictKey, SequenceKey, GetAttrKey, ...)
        # are presumably none of the above and are rendered via their own str().
        return str(thing)

    leaves: dict[str, Array] = {}
    # Walk (path, leaf) pairs directly rather than using ``jax.tree.map_with_path``
    # for its side effects: that would also build and unflatten a throwaway tree
    # of ``None`` leaves, invoking every node's unflatten function for nothing.
    for path, leaf in jax.tree_util.tree_leaves_with_path(pytree):
        if isinstance(leaf, jnp.ndarray):
            leaves[_get_str(path)] = leaf
    return leaves
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.9
3
+ Version: 0.3.10
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes