xax 0.3.9__tar.gz → 0.3.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.3.9/xax.egg-info → xax-0.3.10}/PKG-INFO +1 -1
- {xax-0.3.9 → xax-0.3.10}/xax/__init__.py +1 -1
- {xax-0.3.9 → xax-0.3.10}/xax/utils/pytree.py +11 -4
- {xax-0.3.9 → xax-0.3.10/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.3.9 → xax-0.3.10}/LICENSE +0 -0
- {xax-0.3.9 → xax-0.3.10}/MANIFEST.in +0 -0
- {xax-0.3.9 → xax-0.3.10}/README.md +0 -0
- {xax-0.3.9 → xax-0.3.10}/pyproject.toml +0 -0
- {xax-0.3.9 → xax-0.3.10}/setup.cfg +0 -0
- {xax-0.3.9 → xax-0.3.10}/setup.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/cli/__init__.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/cli/edit_config.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/core/__init__.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/core/conf.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/core/state.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/nn/__init__.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/nn/attention.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/nn/distributions.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/nn/embeddings.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/nn/functions.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/nn/geom.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/nn/losses.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/nn/metrics.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/nn/parallel.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/nn/ssm.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/py.typed +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/requirements-dev.txt +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/requirements.txt +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/__init__.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/base.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/launchers/__init__.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/launchers/base.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/launchers/cli.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/launchers/single_process.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/logger.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/loggers/__init__.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/loggers/callback.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/loggers/json.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/loggers/state.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/loggers/stdout.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/__init__.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/compile.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/logger.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/process.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/runnable.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/mixins/train.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/script.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/task/task.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/__init__.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/data/__init__.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/data/collate.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/debugging.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/experiments.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/jax.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/jaxpr.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/logging.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/numpy.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/profile.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/tensorboard.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/text.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/types/__init__.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax.egg-info/requires.txt +0 -0
- {xax-0.3.9 → xax-0.3.10}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
"""Utils for accessing, modifying, and otherwise manipulating pytrees."""
|
2
2
|
|
3
|
-
from typing import TypeVar
|
3
|
+
from typing import Mapping, Sequence, TypeVar
|
4
4
|
|
5
5
|
import chex
|
6
6
|
import equinox as eqx
|
@@ -258,11 +258,18 @@ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
|
|
258
258
|
def get_pytree_mapping(pytree: PyTree) -> dict[str, Array]:
    """Flattens a pytree into a dictionary of array leaves keyed by path.

    Every leaf of ``pytree`` that is a ``jnp.ndarray`` is collected; all other
    leaves are skipped. Each key is the leaf's path rendered as a string:
    string components are used as-is, sequences are joined with ``"/"``,
    mappings become ``"key:value"`` pairs joined with ``"/"``, and anything
    else goes through ``str()``.

    NOTE(review): JAX normally supplies key-entry objects (``DictKey``,
    ``SequenceKey``, ``GetAttrKey``) as path components. Those are not
    str/Sequence/Mapping instances, so they presumably fall through to
    ``str()`` and render in JAX's own format (e.g. ``['a']`` or ``[0]``).
    Confirm this is the intended naming.

    Args:
        pytree: The pytree to flatten.

    Returns:
        A dict from path string to array leaf. Insertion order follows JAX's
        flattening order. If two paths render to the same string, the later
        leaf overwrites the earlier one.
    """

    def _render(component: PyTree) -> str:
        # ``str`` is itself a Sequence, so it must be checked first; otherwise
        # a one-character string would recurse into itself forever.
        if isinstance(component, str):
            return component
        if isinstance(component, Sequence):
            return "/".join(map(_render, component))
        if isinstance(component, Mapping):
            return "/".join(f"{_render(key)}:{_render(val)}" for key, val in component.items())
        return str(component)

    path_leaf_pairs, _ = jax.tree_util.tree_flatten_with_path(pytree)
    return {_render(path): leaf for path, leaf in path_leaf_pairs if isinstance(leaf, jnp.ndarray)}
|
{xax-0.3.9 → xax-0.3.10}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|