xax 0.1.5__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. {xax-0.1.5/xax.egg-info → xax-0.1.6}/PKG-INFO +1 -1
  2. {xax-0.1.5 → xax-0.1.6}/xax/__init__.py +6 -2
  3. {xax-0.1.5 → xax-0.1.6}/xax/utils/jax.py +22 -0
  4. {xax-0.1.5 → xax-0.1.6/xax.egg-info}/PKG-INFO +1 -1
  5. {xax-0.1.5 → xax-0.1.6}/LICENSE +0 -0
  6. {xax-0.1.5 → xax-0.1.6}/MANIFEST.in +0 -0
  7. {xax-0.1.5 → xax-0.1.6}/README.md +0 -0
  8. {xax-0.1.5 → xax-0.1.6}/pyproject.toml +0 -0
  9. {xax-0.1.5 → xax-0.1.6}/setup.cfg +0 -0
  10. {xax-0.1.5 → xax-0.1.6}/setup.py +0 -0
  11. {xax-0.1.5 → xax-0.1.6}/xax/core/__init__.py +0 -0
  12. {xax-0.1.5 → xax-0.1.6}/xax/core/conf.py +0 -0
  13. {xax-0.1.5 → xax-0.1.6}/xax/core/state.py +0 -0
  14. {xax-0.1.5 → xax-0.1.6}/xax/nn/__init__.py +0 -0
  15. {xax-0.1.5 → xax-0.1.6}/xax/nn/embeddings.py +0 -0
  16. {xax-0.1.5 → xax-0.1.6}/xax/nn/equinox.py +0 -0
  17. {xax-0.1.5 → xax-0.1.6}/xax/nn/export.py +0 -0
  18. {xax-0.1.5 → xax-0.1.6}/xax/nn/functions.py +0 -0
  19. {xax-0.1.5 → xax-0.1.6}/xax/nn/geom.py +0 -0
  20. {xax-0.1.5 → xax-0.1.6}/xax/nn/norm.py +0 -0
  21. {xax-0.1.5 → xax-0.1.6}/xax/nn/parallel.py +0 -0
  22. {xax-0.1.5 → xax-0.1.6}/xax/py.typed +0 -0
  23. {xax-0.1.5 → xax-0.1.6}/xax/requirements-dev.txt +0 -0
  24. {xax-0.1.5 → xax-0.1.6}/xax/requirements.txt +0 -0
  25. {xax-0.1.5 → xax-0.1.6}/xax/task/__init__.py +0 -0
  26. {xax-0.1.5 → xax-0.1.6}/xax/task/base.py +0 -0
  27. {xax-0.1.5 → xax-0.1.6}/xax/task/launchers/__init__.py +0 -0
  28. {xax-0.1.5 → xax-0.1.6}/xax/task/launchers/base.py +0 -0
  29. {xax-0.1.5 → xax-0.1.6}/xax/task/launchers/cli.py +0 -0
  30. {xax-0.1.5 → xax-0.1.6}/xax/task/launchers/single_process.py +0 -0
  31. {xax-0.1.5 → xax-0.1.6}/xax/task/logger.py +0 -0
  32. {xax-0.1.5 → xax-0.1.6}/xax/task/loggers/__init__.py +0 -0
  33. {xax-0.1.5 → xax-0.1.6}/xax/task/loggers/callback.py +0 -0
  34. {xax-0.1.5 → xax-0.1.6}/xax/task/loggers/json.py +0 -0
  35. {xax-0.1.5 → xax-0.1.6}/xax/task/loggers/state.py +0 -0
  36. {xax-0.1.5 → xax-0.1.6}/xax/task/loggers/stdout.py +0 -0
  37. {xax-0.1.5 → xax-0.1.6}/xax/task/loggers/tensorboard.py +0 -0
  38. {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/__init__.py +0 -0
  39. {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/artifacts.py +0 -0
  40. {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/checkpointing.py +0 -0
  41. {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/compile.py +0 -0
  42. {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/cpu_stats.py +0 -0
  43. {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/data_loader.py +0 -0
  44. {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/gpu_stats.py +0 -0
  45. {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/logger.py +0 -0
  46. {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/process.py +0 -0
  47. {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/runnable.py +0 -0
  48. {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/step_wrapper.py +0 -0
  49. {xax-0.1.5 → xax-0.1.6}/xax/task/mixins/train.py +0 -0
  50. {xax-0.1.5 → xax-0.1.6}/xax/task/script.py +0 -0
  51. {xax-0.1.5 → xax-0.1.6}/xax/task/task.py +0 -0
  52. {xax-0.1.5 → xax-0.1.6}/xax/utils/__init__.py +0 -0
  53. {xax-0.1.5 → xax-0.1.6}/xax/utils/data/__init__.py +0 -0
  54. {xax-0.1.5 → xax-0.1.6}/xax/utils/data/collate.py +0 -0
  55. {xax-0.1.5 → xax-0.1.6}/xax/utils/debugging.py +0 -0
  56. {xax-0.1.5 → xax-0.1.6}/xax/utils/experiments.py +0 -0
  57. {xax-0.1.5 → xax-0.1.6}/xax/utils/jaxpr.py +0 -0
  58. {xax-0.1.5 → xax-0.1.6}/xax/utils/logging.py +0 -0
  59. {xax-0.1.5 → xax-0.1.6}/xax/utils/numpy.py +0 -0
  60. {xax-0.1.5 → xax-0.1.6}/xax/utils/profile.py +0 -0
  61. {xax-0.1.5 → xax-0.1.6}/xax/utils/pytree.py +0 -0
  62. {xax-0.1.5 → xax-0.1.6}/xax/utils/tensorboard.py +0 -0
  63. {xax-0.1.5 → xax-0.1.6}/xax/utils/text.py +0 -0
  64. {xax-0.1.5 → xax-0.1.6}/xax.egg-info/SOURCES.txt +0 -0
  65. {xax-0.1.5 → xax-0.1.6}/xax.egg-info/dependency_links.txt +0 -0
  66. {xax-0.1.5 → xax-0.1.6}/xax.egg-info/requires.txt +0 -0
  67. {xax-0.1.5 → xax-0.1.6}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.5"
15
+ __version__ = "0.1.6"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -97,6 +97,8 @@ __all__ = [
97
97
  "save_config",
98
98
  "stage_environment",
99
99
  "to_markdown_table",
100
+ "HashableArray",
101
+ "hashable_array",
100
102
  "jit",
101
103
  "save_jaxpr_dot",
102
104
  "ColoredFormatter",
@@ -251,6 +253,8 @@ NAME_MAP: dict[str, str] = {
251
253
  "save_config": "utils.experiments",
252
254
  "stage_environment": "utils.experiments",
253
255
  "to_markdown_table": "utils.experiments",
256
+ "HashableArray": "utils.jax",
257
+ "hashable_array": "utils.jax",
254
258
  "jit": "utils.jax",
255
259
  "save_jaxpr_dot": "utils.jaxpr",
256
260
  "ColoredFormatter": "utils.logging",
@@ -404,7 +408,7 @@ if IMPORT_ALL or TYPE_CHECKING:
404
408
  stage_environment,
405
409
  to_markdown_table,
406
410
  )
407
- from xax.utils.jax import jit
411
+ from xax.utils.jax import HashableArray, hashable_array, jit
408
412
  from xax.utils.jaxpr import save_jaxpr_dot
409
413
  from xax.utils.logging import (
410
414
  LOG_ERROR_SUMMARY,
@@ -138,3 +138,25 @@ def jit(
138
138
  return wrapped
139
139
 
140
140
  return decorator
141
+
142
+
143
+ class HashableArray:
144
+ def __init__(self, array: np.ndarray | jnp.ndarray) -> None:
145
+ if not isinstance(array, (np.ndarray, jnp.ndarray)):
146
+ raise ValueError(f"Expected np.ndarray or jnp.ndarray, got {type(array)}")
147
+ self.array = array
148
+ self._hash: int | None = None
149
+
150
+ def __hash__(self) -> int:
151
+ if self._hash is None:
152
+ self._hash = hash(self.array.tobytes())
153
+ return self._hash
154
+
155
+ def __eq__(self, other: object) -> bool:
156
+ if not isinstance(other, HashableArray):
157
+ return False
158
+ return bool(jnp.array_equal(self.array, other.array))
159
+
160
+
161
+ def hashable_array(array: np.ndarray | jnp.ndarray) -> HashableArray:
162
+ return HashableArray(array)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes