xax 0.1.5__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.1.5/xax.egg-info → xax-0.1.6}/PKG-INFO +1 -1
- {xax-0.1.5 → xax-0.1.6}/xax/__init__.py +6 -2
- {xax-0.1.5 → xax-0.1.6}/xax/utils/jax.py +22 -0
- {xax-0.1.5 → xax-0.1.6/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.1.5 → xax-0.1.6}/LICENSE +0 -0
- {xax-0.1.5 → xax-0.1.6}/MANIFEST.in +0 -0
- {xax-0.1.5 → xax-0.1.6}/README.md +0 -0
- {xax-0.1.5 → xax-0.1.6}/pyproject.toml +0 -0
- {xax-0.1.5 → xax-0.1.6}/setup.cfg +0 -0
- {xax-0.1.5 → xax-0.1.6}/setup.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/core/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/core/conf.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/core/state.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/nn/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/nn/embeddings.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/nn/equinox.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/nn/export.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/nn/functions.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/nn/geom.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/nn/norm.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/nn/parallel.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/py.typed +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/requirements-dev.txt +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/requirements.txt +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/base.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/launchers/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/launchers/base.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/launchers/cli.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/launchers/single_process.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/logger.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/loggers/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/loggers/callback.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/loggers/json.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/loggers/state.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/loggers/stdout.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/compile.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/logger.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/process.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/runnable.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/train.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/script.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/task/task.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/utils/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/utils/data/__init__.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/utils/data/collate.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/utils/debugging.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/utils/experiments.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/utils/jaxpr.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/utils/logging.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/utils/numpy.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/utils/profile.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/utils/pytree.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/utils/tensorboard.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax/utils/text.py +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax.egg-info/requires.txt +0 -0
- {xax-0.1.5 → xax-0.1.6}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.1.5"
|
15
|
+
__version__ = "0.1.6"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -97,6 +97,8 @@ __all__ = [
|
|
97
97
|
"save_config",
|
98
98
|
"stage_environment",
|
99
99
|
"to_markdown_table",
|
100
|
+
"HashableArray",
|
101
|
+
"hashable_array",
|
100
102
|
"jit",
|
101
103
|
"save_jaxpr_dot",
|
102
104
|
"ColoredFormatter",
|
@@ -251,6 +253,8 @@ NAME_MAP: dict[str, str] = {
|
|
251
253
|
"save_config": "utils.experiments",
|
252
254
|
"stage_environment": "utils.experiments",
|
253
255
|
"to_markdown_table": "utils.experiments",
|
256
|
+
"HashableArray": "utils.jax",
|
257
|
+
"hashable_array": "utils.jax",
|
254
258
|
"jit": "utils.jax",
|
255
259
|
"save_jaxpr_dot": "utils.jaxpr",
|
256
260
|
"ColoredFormatter": "utils.logging",
|
@@ -404,7 +408,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
404
408
|
stage_environment,
|
405
409
|
to_markdown_table,
|
406
410
|
)
|
407
|
-
from xax.utils.jax import jit
|
411
|
+
from xax.utils.jax import HashableArray, hashable_array, jit
|
408
412
|
from xax.utils.jaxpr import save_jaxpr_dot
|
409
413
|
from xax.utils.logging import (
|
410
414
|
LOG_ERROR_SUMMARY,
|
@@ -138,3 +138,25 @@ def jit(
|
|
138
138
|
return wrapped
|
139
139
|
|
140
140
|
return decorator
|
141
|
+
|
142
|
+
|
143
|
+
class HashableArray:
|
144
|
+
def __init__(self, array: np.ndarray | jnp.ndarray) -> None:
|
145
|
+
if not isinstance(array, (np.ndarray, jnp.ndarray)):
|
146
|
+
raise ValueError(f"Expected np.ndarray or jnp.ndarray, got {type(array)}")
|
147
|
+
self.array = array
|
148
|
+
self._hash: int | None = None
|
149
|
+
|
150
|
+
def __hash__(self) -> int:
|
151
|
+
if self._hash is None:
|
152
|
+
self._hash = hash(self.array.tobytes())
|
153
|
+
return self._hash
|
154
|
+
|
155
|
+
def __eq__(self, other: object) -> bool:
|
156
|
+
if not isinstance(other, HashableArray):
|
157
|
+
return False
|
158
|
+
return bool(jnp.array_equal(self.array, other.array))
|
159
|
+
|
160
|
+
|
161
|
+
def hashable_array(array: np.ndarray | jnp.ndarray) -> HashableArray:
|
162
|
+
return HashableArray(array)
|
{xax-0.1.5 → xax-0.1.6}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{xax-0.1.5 → xax-0.1.6}/setup.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|