xax 0.1.5__tar.gz → 0.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. {xax-0.1.5/xax.egg-info → xax-0.1.7}/PKG-INFO +1 -1
  2. {xax-0.1.5 → xax-0.1.7}/xax/__init__.py +10 -13
  3. {xax-0.1.5 → xax-0.1.7}/xax/nn/equinox.py +2 -2
  4. {xax-0.1.5 → xax-0.1.7}/xax/nn/export.py +5 -5
  5. {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/compile.py +1 -1
  6. xax-0.1.7/xax/utils/types/__init__.py +0 -0
  7. xax-0.1.7/xax/utils/types/frozen_dict.py +148 -0
  8. xax-0.1.7/xax/utils/types/hashable_array.py +31 -0
  9. {xax-0.1.5 → xax-0.1.7/xax.egg-info}/PKG-INFO +1 -1
  10. {xax-0.1.5 → xax-0.1.7}/xax.egg-info/SOURCES.txt +4 -1
  11. {xax-0.1.5 → xax-0.1.7}/LICENSE +0 -0
  12. {xax-0.1.5 → xax-0.1.7}/MANIFEST.in +0 -0
  13. {xax-0.1.5 → xax-0.1.7}/README.md +0 -0
  14. {xax-0.1.5 → xax-0.1.7}/pyproject.toml +0 -0
  15. {xax-0.1.5 → xax-0.1.7}/setup.cfg +0 -0
  16. {xax-0.1.5 → xax-0.1.7}/setup.py +0 -0
  17. {xax-0.1.5 → xax-0.1.7}/xax/core/__init__.py +0 -0
  18. {xax-0.1.5 → xax-0.1.7}/xax/core/conf.py +0 -0
  19. {xax-0.1.5 → xax-0.1.7}/xax/core/state.py +0 -0
  20. {xax-0.1.5 → xax-0.1.7}/xax/nn/__init__.py +0 -0
  21. {xax-0.1.5 → xax-0.1.7}/xax/nn/embeddings.py +0 -0
  22. {xax-0.1.5 → xax-0.1.7}/xax/nn/functions.py +0 -0
  23. {xax-0.1.5 → xax-0.1.7}/xax/nn/geom.py +0 -0
  24. {xax-0.1.5 → xax-0.1.7}/xax/nn/norm.py +0 -0
  25. {xax-0.1.5 → xax-0.1.7}/xax/nn/parallel.py +0 -0
  26. {xax-0.1.5 → xax-0.1.7}/xax/py.typed +0 -0
  27. {xax-0.1.5 → xax-0.1.7}/xax/requirements-dev.txt +0 -0
  28. {xax-0.1.5 → xax-0.1.7}/xax/requirements.txt +0 -0
  29. {xax-0.1.5 → xax-0.1.7}/xax/task/__init__.py +0 -0
  30. {xax-0.1.5 → xax-0.1.7}/xax/task/base.py +0 -0
  31. {xax-0.1.5 → xax-0.1.7}/xax/task/launchers/__init__.py +0 -0
  32. {xax-0.1.5 → xax-0.1.7}/xax/task/launchers/base.py +0 -0
  33. {xax-0.1.5 → xax-0.1.7}/xax/task/launchers/cli.py +0 -0
  34. {xax-0.1.5 → xax-0.1.7}/xax/task/launchers/single_process.py +0 -0
  35. {xax-0.1.5 → xax-0.1.7}/xax/task/logger.py +0 -0
  36. {xax-0.1.5 → xax-0.1.7}/xax/task/loggers/__init__.py +0 -0
  37. {xax-0.1.5 → xax-0.1.7}/xax/task/loggers/callback.py +0 -0
  38. {xax-0.1.5 → xax-0.1.7}/xax/task/loggers/json.py +0 -0
  39. {xax-0.1.5 → xax-0.1.7}/xax/task/loggers/state.py +0 -0
  40. {xax-0.1.5 → xax-0.1.7}/xax/task/loggers/stdout.py +0 -0
  41. {xax-0.1.5 → xax-0.1.7}/xax/task/loggers/tensorboard.py +0 -0
  42. {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/__init__.py +0 -0
  43. {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/artifacts.py +0 -0
  44. {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/checkpointing.py +0 -0
  45. {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/cpu_stats.py +0 -0
  46. {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/data_loader.py +0 -0
  47. {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/gpu_stats.py +0 -0
  48. {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/logger.py +0 -0
  49. {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/process.py +0 -0
  50. {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/runnable.py +0 -0
  51. {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/step_wrapper.py +0 -0
  52. {xax-0.1.5 → xax-0.1.7}/xax/task/mixins/train.py +0 -0
  53. {xax-0.1.5 → xax-0.1.7}/xax/task/script.py +0 -0
  54. {xax-0.1.5 → xax-0.1.7}/xax/task/task.py +0 -0
  55. {xax-0.1.5 → xax-0.1.7}/xax/utils/__init__.py +0 -0
  56. {xax-0.1.5 → xax-0.1.7}/xax/utils/data/__init__.py +0 -0
  57. {xax-0.1.5 → xax-0.1.7}/xax/utils/data/collate.py +0 -0
  58. {xax-0.1.5 → xax-0.1.7}/xax/utils/debugging.py +0 -0
  59. {xax-0.1.5 → xax-0.1.7}/xax/utils/experiments.py +0 -0
  60. {xax-0.1.5 → xax-0.1.7}/xax/utils/jax.py +0 -0
  61. {xax-0.1.5 → xax-0.1.7}/xax/utils/jaxpr.py +0 -0
  62. {xax-0.1.5 → xax-0.1.7}/xax/utils/logging.py +0 -0
  63. {xax-0.1.5 → xax-0.1.7}/xax/utils/numpy.py +0 -0
  64. {xax-0.1.5 → xax-0.1.7}/xax/utils/profile.py +0 -0
  65. {xax-0.1.5 → xax-0.1.7}/xax/utils/pytree.py +0 -0
  66. {xax-0.1.5 → xax-0.1.7}/xax/utils/tensorboard.py +0 -0
  67. {xax-0.1.5 → xax-0.1.7}/xax/utils/text.py +0 -0
  68. {xax-0.1.5 → xax-0.1.7}/xax.egg-info/dependency_links.txt +0 -0
  69. {xax-0.1.5 → xax-0.1.7}/xax.egg-info/requires.txt +0 -0
  70. {xax-0.1.5 → xax-0.1.7}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.5
3
+ Version: 0.1.7
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.5"
15
+ __version__ = "0.1.7"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -41,9 +41,6 @@ __all__ = [
41
41
  "load_eqx_mlp",
42
42
  "make_eqx_mlp",
43
43
  "save_eqx",
44
- "export",
45
- "export_flax",
46
- "export_with_params",
47
44
  "euler_to_quat",
48
45
  "get_projected_gravity_vector_from_quat",
49
46
  "quat_to_euler",
@@ -130,6 +127,9 @@ __all__ = [
130
127
  "snakecase_to_camelcase",
131
128
  "uncolored",
132
129
  "wrapped",
130
+ "FrozenDict",
131
+ "HashableArray",
132
+ "hashable_array",
133
133
  ]
134
134
 
135
135
  __all__ += [
@@ -159,7 +159,7 @@ if "XLA_FLAGS" in os.environ:
159
159
  # If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), disable
160
160
  # Triton GEMM kernels. See https://github.com/NVIDIA/JAX-Toolbox
161
161
  if shutil.which("nvidia-smi") is not None:
162
- xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler", "--xla_gpu_enable_triton_gemm"]
162
+ xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler=true", "--xla_gpu_enable_triton_gemm=false"]
163
163
  os.environ["XLA_FLAGS"] = " ".join(xla_flags)
164
164
 
165
165
  # If this flag is set, eagerly imports the entire package (not recommended).
@@ -195,9 +195,6 @@ NAME_MAP: dict[str, str] = {
195
195
  "load_eqx_mlp": "nn.equinox",
196
196
  "make_eqx_mlp": "nn.equinox",
197
197
  "save_eqx": "nn.equinox",
198
- "export": "nn.export",
199
- "export_flax": "nn.export",
200
- "export_with_params": "nn.export",
201
198
  "euler_to_quat": "nn.geom",
202
199
  "get_projected_gravity_vector_from_quat": "nn.geom",
203
200
  "quat_to_euler": "nn.geom",
@@ -284,6 +281,9 @@ NAME_MAP: dict[str, str] = {
284
281
  "snakecase_to_camelcase": "utils.text",
285
282
  "uncolored": "utils.text",
286
283
  "wrapped": "utils.text",
284
+ "FrozenDict": "utils.types.frozen_dict",
285
+ "HashableArray": "utils.types.hashable_array",
286
+ "hashable_array": "utils.types.hashable_array",
287
287
  }
288
288
 
289
289
  # Need to manually set some values which can't be auto-generated.
@@ -348,11 +348,6 @@ if IMPORT_ALL or TYPE_CHECKING:
348
348
  make_eqx_mlp,
349
349
  save_eqx,
350
350
  )
351
- from xax.nn.export import (
352
- export,
353
- export_flax,
354
- export_with_params,
355
- )
356
351
  from xax.nn.geom import (
357
352
  euler_to_quat,
358
353
  get_projected_gravity_vector_from_quat,
@@ -444,5 +439,7 @@ if IMPORT_ALL or TYPE_CHECKING:
444
439
  uncolored,
445
440
  wrapped,
446
441
  )
442
+ from xax.utils.types.frozen_dict import FrozenDict
443
+ from xax.utils.types.hashable_array import HashableArray, hashable_array
447
444
 
448
445
  del TYPE_CHECKING, IMPORT_ALL
@@ -72,7 +72,7 @@ def _infer_activation(activation: ActivationFunction) -> Callable:
72
72
  raise ValueError(f"Activation function `{activation}` not found in `jax.nn`")
73
73
 
74
74
 
75
- def make_eqx_mlp(hyperparams: MLPHyperParams, key: PRNGKeyArray = jax.random.PRNGKey(0)) -> eqx.nn.MLP:
75
+ def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.MLP:
76
76
  """Create an Equinox MLP from a set of hyperparameters.
77
77
 
78
78
  Args:
@@ -176,5 +176,5 @@ def load_eqx_mlp(
176
176
  ) -> eqx.nn.MLP:
177
177
  with open(eqx_file, "rb") as f:
178
178
  hyperparams = json.loads(f.readline().decode(encoding="utf-8"))
179
- model = make_eqx_mlp(hyperparams=hyperparams)
179
+ model = make_eqx_mlp(hyperparams=hyperparams, key=jax.random.PRNGKey(0))
180
180
  return eqx.tree_deserialise_leaves(f, model)
@@ -4,18 +4,18 @@ import logging
4
4
  from pathlib import Path
5
5
  from typing import Callable
6
6
 
7
- import flax
8
7
  import jax
9
- import tensorflow as tf
10
- from jax.experimental import jax2tf
11
8
  from jaxtyping import Array, PyTree
12
9
 
13
10
  try:
11
+ import flax
12
+ import tensorflow as tf
13
+ from jax.experimental import jax2tf
14
14
  from orbax.export import ExportManager, JaxModule, ServingConfig
15
15
  except ImportError as e:
16
16
  raise ImportError(
17
- "Please install the package with `orbax` as a dependency, using "
18
- "'xax[export]` to install the required dependencies."
17
+ "In order to export models, please install Xax with export dependencies, "
18
+ "using 'xax[export]` to install the required dependencies."
19
19
  ) from e
20
20
 
21
21
  logger = logging.getLogger(__name__)
@@ -22,7 +22,7 @@ def get_cache_dir() -> str | None:
22
22
  # By default, cache on MacOS and Linux (Linux caching was previously
23
23
  # disabled because it was prone to NaNs); other platforms get no cache.
24
24
  match sys.platform:
25
- case "darwin":
25
+ case "darwin" | "linux":
26
26
  return str((Path.home() / ".cache" / "jax" / "jaxcache").resolve())
27
27
  case _:
28
28
  return None
File without changes
@@ -0,0 +1,148 @@
1
+ """Defines a frozen dictionary type.
2
+
3
+ This is mostly taken from Flax - we move it here to avoid having to use Flax as
4
+ a dependency in downstream projects.
5
+ """
6
+
7
+ import collections
8
+ from types import MappingProxyType
9
+ from typing import Any, Iterator, Mapping, Self, TypeVar
10
+
11
+ import jax
12
+
13
# Type variables for the key (K) and value (V) types of ``FrozenDict`` and its
# key/value views. Kept in the ``X = TypeVar("X")`` form type checkers require.
K = TypeVar("K")
V = TypeVar("V")
+
16
+
17
def _prepare_freeze(xs: Any) -> Any:  # noqa: ANN401
    """Return a version of ``xs`` that is safe to store inside a FrozenDict.

    - A ``FrozenDict`` is unwrapped to its internal dict (no copy is made).
    - A plain ``dict`` is rebuilt recursively, so nested dicts are fresh copies.
    - Any other value is returned unchanged.
    """
    if isinstance(xs, FrozenDict):
        return xs._dict  # pylint: disable=protected-access
    if isinstance(xs, dict):
        return {k: _prepare_freeze(v) for k, v in xs.items()}
    return xs
24
+
25
+
26
+ def _indent(x: str, num_spaces: int) -> str:
27
+ indent_str = " " * num_spaces
28
+ lines = x.split("\n")
29
+ assert not lines[-1]
30
+ return "\n".join(indent_str + line for line in lines[:-1]) + "\n"
31
+
32
+
33
class FrozenKeysView(collections.abc.KeysView[K]):
    """Keys view returned by ``FrozenDict.keys()``, with a distinctive repr."""

    def __repr__(self) -> str:
        keys = [*self]
        return "frozen_dict_keys(" + repr(keys) + ")"
36
+
37
+
38
class FrozenValuesView(collections.abc.ValuesView[V]):
    """Values view returned by ``FrozenDict.values()``, with a distinctive repr."""

    def __repr__(self) -> str:
        values = [*self]
        return "frozen_dict_values(" + repr(values) + ")"
41
+
42
+
43
@jax.tree_util.register_pytree_with_keys_class
class FrozenDict(Mapping[K, V]):
    """An immutable variant of the Python dict.

    Nested plain dicts passed in are deep-copied on construction. When read
    back through ``__getitem__``/``items`` they are wrapped as ``FrozenDict``,
    so the stored state cannot be mutated through the public interface.

    The class is registered as a JAX pytree (with keys). It flattens its
    entries in sorted-key order.
    """

    __slots__ = ("_dict", "_hash")

    def __init__(self, *args: Any, __unsafe_skip_copy__: bool = False, **kwargs: Any) -> None:  # noqa: ANN401
        """Build a frozen dict from the same arguments ``dict()`` accepts.

        Args:
            *args: Positional arguments forwarded to ``dict()``.
            __unsafe_skip_copy__: If True, store the shallow ``dict(...)``
                result directly. Nested dicts are not deep-copied and nested
                ``FrozenDict`` values are not unwrapped. Only appropriate when
                the caller already owns the nested containers (as
                ``tree_unflatten`` does).
            **kwargs: Keyword arguments forwarded to ``dict()``.
        """
        # Always take a shallow copy first so the caller's top-level dict is
        # never aliased; then, unless told otherwise, deep-copy nested dicts.
        xs = dict(*args, **kwargs)
        if __unsafe_skip_copy__:
            self._dict = xs
        else:
            self._dict = _prepare_freeze(xs)

        # Computed lazily and cached by ``__hash__``.
        self._hash: int | None = None

    def __getitem__(self, key: K) -> V:
        """Return the value for ``key``, wrapping a nested dict as ``FrozenDict``."""
        v = self._dict[key]
        if isinstance(v, dict):
            return FrozenDict(v)  # type: ignore[return-value]
        return v

    def __setitem__(self, key: K, value: V) -> None:
        raise ValueError("FrozenDict is immutable.")

    def __contains__(self, key: object) -> bool:
        return key in self._dict

    def __iter__(self) -> Iterator[K]:
        return iter(self._dict)

    def __len__(self) -> int:
        return len(self._dict)

    def __repr__(self) -> str:
        return self.pretty_repr()

    def __reduce__(self) -> tuple[type["FrozenDict[K, V]"], tuple[dict[K, V]]]:
        # Rebuild with ``type(self)``, as ``copy`` and ``pop`` already do. A
        # pickled subclass instance then round-trips as that subclass instead
        # of being silently downgraded to a plain ``FrozenDict``.
        return type(self), (self.unfreeze(),)

    def pretty_repr(self, num_spaces: int = 4) -> str:
        """Return an indented, multi-line representation of the nested dict.

        Args:
            num_spaces: Number of spaces per indentation level.
        """

        def pretty_dict(x: Any) -> str:  # noqa: ANN401
            if not isinstance(x, dict):
                return repr(x)
            rep = ""
            for key, val in x.items():
                rep += f"{key}: {pretty_dict(val)},\n"
            if rep:
                return "{\n" + _indent(rep, num_spaces) + "}"
            else:
                return "{}"

        return f"FrozenDict({pretty_dict(self._dict)})"

    def __hash__(self) -> int:
        """Return an order-independent hash of the items, cached after first use.

        Raises ``TypeError`` if any value is unhashable.
        """
        if self._hash is None:
            h = 0
            # XOR makes the result independent of insertion order.
            for key, value in self.items():
                h ^= hash((key, value))
            self._hash = h
        return self._hash

    def copy(self, add_or_replace: Mapping[K, V] = MappingProxyType({})) -> Self:
        """Return a new instance with ``add_or_replace`` merged over these items."""
        return type(self)({**self, **unfreeze(add_or_replace)})  # type: ignore[arg-type]

    def keys(self) -> FrozenKeysView[K]:
        return FrozenKeysView(self)

    def values(self) -> FrozenValuesView[V]:
        return FrozenValuesView(self)

    def items(self) -> Iterator[tuple[K, V]]:  # type: ignore[override]
        """Yield ``(key, value)`` pairs; nested dicts come back as ``FrozenDict``.

        Note: this is a one-shot generator, not a reusable ``ItemsView``.
        """
        for key in self._dict:
            yield (key, self[key])

    def pop(self, key: K) -> tuple["FrozenDict[K, V]", V]:
        """Return ``(new_frozen_dict_without_key, value)``.

        Raises ``KeyError`` if ``key`` is absent.
        """
        value = self[key]
        new_dict = dict(self._dict)
        new_dict.pop(key)
        new_self = type(self)(new_dict)
        return new_self, value

    def unfreeze(self) -> dict[K, V]:
        """Return a mutable plain-dict copy of this instance."""
        return unfreeze(self)

    def tree_flatten_with_keys(self) -> tuple[tuple[tuple[jax.tree_util.DictKey, Any], ...], tuple[K, ...]]:
        # Children are emitted in sorted-key order so flattening is
        # deterministic regardless of insertion order.
        sorted_keys = sorted(self._dict)
        return tuple([(jax.tree_util.DictKey(k), self._dict[k]) for k in sorted_keys]), tuple(sorted_keys)

    @classmethod
    def tree_unflatten(cls, keys: tuple[K, ...], values: tuple[Any, ...]) -> "FrozenDict[K, V]":
        # The children were just rebuilt by JAX, so the defensive deep copy
        # is skipped.
        return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True)
137
+
138
+
139
+ def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]: # noqa: ANN401
140
+ if isinstance(x, FrozenDict):
141
+ return jax.tree_util.tree_map(lambda y: y, x._dict)
142
+ elif isinstance(x, dict):
143
+ ys = {}
144
+ for key, value in x.items():
145
+ ys[key] = unfreeze(value)
146
+ return ys
147
+ else:
148
+ return x
@@ -0,0 +1,31 @@
1
+ """Defines a hashable array wrapper.
2
+
3
+ Since Jax relies extensively on hashing, and we sometimes want to treat Jax
4
+ arrays as constants, this wrapper lets us ensure that Jax and Numpy arrays can
5
+ be hashed for Jitting.
6
+ """
7
+
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+
11
+
12
+ class HashableArray:
13
+ def __init__(self, array: np.ndarray | jnp.ndarray) -> None:
14
+ if not isinstance(array, (np.ndarray, jnp.ndarray)):
15
+ raise ValueError(f"Expected np.ndarray or jnp.ndarray, got {type(array)}")
16
+ self.array = array
17
+ self._hash: int | None = None
18
+
19
+ def __hash__(self) -> int:
20
+ if self._hash is None:
21
+ self._hash = hash(self.array.tobytes())
22
+ return self._hash
23
+
24
+ def __eq__(self, other: object) -> bool:
25
+ if not isinstance(other, HashableArray):
26
+ return False
27
+ return bool(jnp.array_equal(self.array, other.array))
28
+
29
+
30
def hashable_array(array: np.ndarray | jnp.ndarray) -> HashableArray:
    """Functional-style constructor: wrap ``array`` in a ``HashableArray``."""
    wrapper = HashableArray(array)
    return wrapper
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.5
3
+ Version: 0.1.7
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -63,4 +63,7 @@ xax/utils/pytree.py
63
63
  xax/utils/tensorboard.py
64
64
  xax/utils/text.py
65
65
  xax/utils/data/__init__.py
66
- xax/utils/data/collate.py
66
+ xax/utils/data/collate.py
67
+ xax/utils/types/__init__.py
68
+ xax/utils/types/frozen_dict.py
69
+ xax/utils/types/hashable_array.py
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes