xax 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.6"
15
+ __version__ = "0.1.8"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -97,8 +97,6 @@ __all__ = [
97
97
  "save_config",
98
98
  "stage_environment",
99
99
  "to_markdown_table",
100
- "HashableArray",
101
- "hashable_array",
102
100
  "jit",
103
101
  "save_jaxpr_dot",
104
102
  "ColoredFormatter",
@@ -132,6 +130,9 @@ __all__ = [
132
130
  "snakecase_to_camelcase",
133
131
  "uncolored",
134
132
  "wrapped",
133
+ "FrozenDict",
134
+ "HashableArray",
135
+ "hashable_array",
135
136
  ]
136
137
 
137
138
  __all__ += [
@@ -161,7 +162,7 @@ if "XLA_FLAGS" in os.environ:
161
162
  # If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), disable
162
163
  # Triton GEMM kernels. See https://github.com/NVIDIA/JAX-Toolbox
163
164
  if shutil.which("nvidia-smi") is not None:
164
- xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler", "--xla_gpu_enable_triton_gemm"]
165
+ xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler=true", "--xla_gpu_enable_triton_gemm=false"]
165
166
  os.environ["XLA_FLAGS"] = " ".join(xla_flags)
166
167
 
167
168
  # If this flag is set, eagerly imports the entire package (not recommended).
@@ -253,8 +254,6 @@ NAME_MAP: dict[str, str] = {
253
254
  "save_config": "utils.experiments",
254
255
  "stage_environment": "utils.experiments",
255
256
  "to_markdown_table": "utils.experiments",
256
- "HashableArray": "utils.jax",
257
- "hashable_array": "utils.jax",
258
257
  "jit": "utils.jax",
259
258
  "save_jaxpr_dot": "utils.jaxpr",
260
259
  "ColoredFormatter": "utils.logging",
@@ -288,6 +287,9 @@ NAME_MAP: dict[str, str] = {
288
287
  "snakecase_to_camelcase": "utils.text",
289
288
  "uncolored": "utils.text",
290
289
  "wrapped": "utils.text",
290
+ "FrozenDict": "utils.types.frozen_dict",
291
+ "HashableArray": "utils.types.hashable_array",
292
+ "hashable_array": "utils.types.hashable_array",
291
293
  }
292
294
 
293
295
  # Need to manually set some values which can't be auto-generated.
@@ -408,7 +410,7 @@ if IMPORT_ALL or TYPE_CHECKING:
408
410
  stage_environment,
409
411
  to_markdown_table,
410
412
  )
411
- from xax.utils.jax import HashableArray, hashable_array, jit
413
+ from xax.utils.jax import jit
412
414
  from xax.utils.jaxpr import save_jaxpr_dot
413
415
  from xax.utils.logging import (
414
416
  LOG_ERROR_SUMMARY,
@@ -448,5 +450,7 @@ if IMPORT_ALL or TYPE_CHECKING:
448
450
  uncolored,
449
451
  wrapped,
450
452
  )
453
+ from xax.utils.types.frozen_dict import FrozenDict
454
+ from xax.utils.types.hashable_array import HashableArray, hashable_array
451
455
 
452
456
  del TYPE_CHECKING, IMPORT_ALL
xax/nn/equinox.py CHANGED
@@ -72,7 +72,7 @@ def _infer_activation(activation: ActivationFunction) -> Callable:
72
72
  raise ValueError(f"Activation function `{activation}` not found in `jax.nn`")
73
73
 
74
74
 
75
- def make_eqx_mlp(hyperparams: MLPHyperParams, key: PRNGKeyArray = jax.random.PRNGKey(0)) -> eqx.nn.MLP:
75
+ def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.MLP:
76
76
  """Create an Equinox MLP from a set of hyperparameters.
77
77
 
78
78
  Args:
@@ -176,5 +176,5 @@ def load_eqx_mlp(
176
176
  ) -> eqx.nn.MLP:
177
177
  with open(eqx_file, "rb") as f:
178
178
  hyperparams = json.loads(f.readline().decode(encoding="utf-8"))
179
- model = make_eqx_mlp(hyperparams=hyperparams)
179
+ model = make_eqx_mlp(hyperparams=hyperparams, key=jax.random.PRNGKey(0))
180
180
  return eqx.tree_deserialise_leaves(f, model)
xax/nn/export.py CHANGED
@@ -4,18 +4,18 @@ import logging
4
4
  from pathlib import Path
5
5
  from typing import Callable
6
6
 
7
- import flax
8
7
  import jax
9
- import tensorflow as tf
10
- from jax.experimental import jax2tf
11
8
  from jaxtyping import Array, PyTree
12
9
 
13
10
  try:
11
+ import flax
12
+ import tensorflow as tf
13
+ from jax.experimental import jax2tf
14
14
  from orbax.export import ExportManager, JaxModule, ServingConfig
15
15
  except ImportError as e:
16
16
  raise ImportError(
17
- "Please install the package with `orbax` as a dependency, using "
18
- "'xax[export]` to install the required dependencies."
17
+ "In order to export models, please install Xax with export dependencies, "
18
+ "using 'xax[export]` to install the required dependencies."
19
19
  ) from e
20
20
 
21
21
  logger = logging.getLogger(__name__)
xax/task/mixins/compile.py CHANGED
@@ -22,7 +22,7 @@ def get_cache_dir() -> str | None:
22
22
  # By default, only cache on MacOS, since Jax caching on Linux is very
23
23
  # prone to NaNs.
24
24
  match sys.platform:
25
- case "darwin":
25
+ case "darwin" | "linux":
26
26
  return str((Path.home() / ".cache" / "jax" / "jaxcache").resolve())
27
27
  case _:
28
28
  return None
xax/utils/jax.py CHANGED
@@ -138,25 +138,3 @@ def jit(
138
138
  return wrapped
139
139
 
140
140
  return decorator
141
-
142
-
143
- class HashableArray:
144
- def __init__(self, array: np.ndarray | jnp.ndarray) -> None:
145
- if not isinstance(array, (np.ndarray, jnp.ndarray)):
146
- raise ValueError(f"Expected np.ndarray or jnp.ndarray, got {type(array)}")
147
- self.array = array
148
- self._hash: int | None = None
149
-
150
- def __hash__(self) -> int:
151
- if self._hash is None:
152
- self._hash = hash(self.array.tobytes())
153
- return self._hash
154
-
155
- def __eq__(self, other: object) -> bool:
156
- if not isinstance(other, HashableArray):
157
- return False
158
- return bool(jnp.array_equal(self.array, other.array))
159
-
160
-
161
- def hashable_array(array: np.ndarray | jnp.ndarray) -> HashableArray:
162
- return HashableArray(array)
xax/utils/types/__init__.py ADDED
File without changes
xax/utils/types/frozen_dict.py ADDED
@@ -0,0 +1,148 @@
1
+ """Defines a frozen dictionary type.
2
+
3
+ This is mostly taken from Flax - we move it here to avoid having to use Flax as
4
+ a dependency in downstream projects.
5
+ """
6
+
7
+ import collections
8
+ from types import MappingProxyType
9
+ from typing import Any, Iterator, Mapping, Self, TypeVar
10
+
11
+ import jax
12
+
13
+ K = TypeVar("K")
14
+ V = TypeVar("V")
15
+
16
+
17
+ def _prepare_freeze(xs: Any) -> Any: # noqa: ANN401
18
+ """Deep copy unfrozen dicts to make the dictionary FrozenDict safe."""
19
+ if isinstance(xs, FrozenDict):
20
+ return xs._dict # pylint: disable=protected-access
21
+ if not isinstance(xs, dict):
22
+ return xs
23
+ return {key: _prepare_freeze(val) for key, val in xs.items()}
24
+
25
+
26
+ def _indent(x: str, num_spaces: int) -> str:
27
+ indent_str = " " * num_spaces
28
+ lines = x.split("\n")
29
+ assert not lines[-1]
30
+ return "\n".join(indent_str + line for line in lines[:-1]) + "\n"
31
+
32
+
33
+ class FrozenKeysView(collections.abc.KeysView[K]):
34
+ def __repr__(self) -> str:
35
+ return f"frozen_dict_keys({list(self)})"
36
+
37
+
38
+ class FrozenValuesView(collections.abc.ValuesView[V]):
39
+ def __repr__(self) -> str:
40
+ return f"frozen_dict_values({list(self)})"
41
+
42
+
43
+ @jax.tree_util.register_pytree_with_keys_class
44
+ class FrozenDict(Mapping[K, V]):
45
+ """An immutable variant of the Python dict."""
46
+
47
+ __slots__ = ("_dict", "_hash")
48
+
49
+ def __init__(self, *args: Any, __unsafe_skip_copy__: bool = False, **kwargs: Any) -> None: # noqa: ANN401
50
+ # make sure the dict is as
51
+ xs = dict(*args, **kwargs)
52
+ if __unsafe_skip_copy__:
53
+ self._dict = xs
54
+ else:
55
+ self._dict = _prepare_freeze(xs)
56
+
57
+ self._hash: int | None = None
58
+
59
+ def __getitem__(self, key: K) -> V:
60
+ v = self._dict[key]
61
+ if isinstance(v, dict):
62
+ return FrozenDict(v) # type: ignore[return-value]
63
+ return v
64
+
65
+ def __setitem__(self, key: K, value: V) -> None:
66
+ raise ValueError("FrozenDict is immutable.")
67
+
68
+ def __contains__(self, key: object) -> bool:
69
+ return key in self._dict
70
+
71
+ def __iter__(self) -> Iterator[K]:
72
+ return iter(self._dict)
73
+
74
+ def __len__(self) -> int:
75
+ return len(self._dict)
76
+
77
+ def __repr__(self) -> str:
78
+ return self.pretty_repr()
79
+
80
+ def __reduce__(self) -> tuple[type["FrozenDict[K, V]"], tuple[dict[K, V]]]:
81
+ return FrozenDict, (self.unfreeze(),)
82
+
83
+ def pretty_repr(self, num_spaces: int = 4) -> str:
84
+ """Returns an indented representation of the nested dictionary."""
85
+
86
+ def pretty_dict(x: Any) -> str: # noqa: ANN401
87
+ if not isinstance(x, dict):
88
+ return repr(x)
89
+ rep = ""
90
+ for key, val in x.items():
91
+ rep += f"{key}: {pretty_dict(val)},\n"
92
+ if rep:
93
+ return "{\n" + _indent(rep, num_spaces) + "}"
94
+ else:
95
+ return "{}"
96
+
97
+ return f"FrozenDict({pretty_dict(self._dict)})"
98
+
99
+ def __hash__(self) -> int:
100
+ if self._hash is None:
101
+ h = 0
102
+ for key, value in self.items():
103
+ h ^= hash((key, value))
104
+ self._hash = h
105
+ return self._hash
106
+
107
+ def copy(self, add_or_replace: Mapping[K, V] = MappingProxyType({})) -> Self:
108
+ return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type]
109
+
110
+ def keys(self) -> FrozenKeysView[K]:
111
+ return FrozenKeysView(self)
112
+
113
+ def values(self) -> FrozenValuesView[V]:
114
+ return FrozenValuesView(self)
115
+
116
+ def items(self) -> Iterator[tuple[K, V]]: # type: ignore[override]
117
+ for key in self._dict:
118
+ yield (key, self[key])
119
+
120
+ def pop(self, key: K) -> tuple["FrozenDict[K, V]", V]:
121
+ value = self[key]
122
+ new_dict = dict(self._dict)
123
+ new_dict.pop(key)
124
+ new_self = type(self)(new_dict)
125
+ return new_self, value
126
+
127
+ def unfreeze(self) -> dict[K, V]:
128
+ return unfreeze(self)
129
+
130
+ def tree_flatten_with_keys(self) -> tuple[tuple[tuple[jax.tree_util.DictKey, Any], ...], tuple[K, ...]]:
131
+ sorted_keys = sorted(self._dict)
132
+ return tuple([(jax.tree_util.DictKey(k), self._dict[k]) for k in sorted_keys]), tuple(sorted_keys)
133
+
134
+ @classmethod
135
+ def tree_unflatten(cls, keys: tuple[K, ...], values: tuple[Any, ...]) -> "FrozenDict[K, V]":
136
+ return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True)
137
+
138
+
139
+ def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]: # noqa: ANN401
140
+ if isinstance(x, FrozenDict):
141
+ return jax.tree_util.tree_map(lambda y: y, x._dict)
142
+ elif isinstance(x, dict):
143
+ ys = {}
144
+ for key, value in x.items():
145
+ ys[key] = unfreeze(value)
146
+ return ys
147
+ else:
148
+ return x
xax/utils/types/hashable_array.py ADDED
@@ -0,0 +1,31 @@
1
+ """Defines a hashable array wrapper.
2
+
3
+ Since Jax relies extensively on hashing, and we sometimes want to treat Jax
4
+ arrays as constants, this wrapper lets us ensure that Jax and Numpy arrays can
5
+ be hashed for Jitting.
6
+ """
7
+
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+
11
+
12
+ class HashableArray:
13
+ def __init__(self, array: np.ndarray | jnp.ndarray) -> None:
14
+ if not isinstance(array, (np.ndarray, jnp.ndarray)):
15
+ raise ValueError(f"Expected np.ndarray or jnp.ndarray, got {type(array)}")
16
+ self.array = array
17
+ self._hash: int | None = None
18
+
19
+ def __hash__(self) -> int:
20
+ if self._hash is None:
21
+ self._hash = hash(self.array.tobytes())
22
+ return self._hash
23
+
24
+ def __eq__(self, other: object) -> bool:
25
+ if not isinstance(other, HashableArray):
26
+ return False
27
+ return bool(jnp.array_equal(self.array, other.array))
28
+
29
+
30
+ def hashable_array(array: np.ndarray | jnp.ndarray) -> HashableArray:
31
+ return HashableArray(array)
xax-0.1.6.dist-info/METADATA → xax-0.1.8.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.6
3
+ Version: 0.1.8
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
xax-0.1.6.dist-info/RECORD → xax-0.1.8.dist-info/RECORD RENAMED
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=rjqydWhxQVUAj3lXgFpzj4iLFOdDJGHArfAH7_QSkhk,13504
1
+ xax/__init__.py,sha256=9WNjoeAF7enu7YXQqshpVG1FucGdSkxwrRa-ELDDuUs,13713
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -7,8 +7,8 @@ xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
7
7
  xax/core/state.py,sha256=y123fL7pMgk25TPG6KN0LRIF_eYnD9eP7OfqtoQJGNE,2178
8
8
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
- xax/nn/equinox.py,sha256=1Ck6ycz76dhit2LHX4y2lp3WJSPsDuRt7TK7AxxQhww,4837
11
- xax/nn/export.py,sha256=bu2m-4FDnadEhXDb9zM6SgOZvsf5p4xiee1sFZyNF7c,5510
10
+ xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
11
+ xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
12
12
  xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
13
13
  xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
14
14
  xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
@@ -31,7 +31,7 @@ xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCn
31
31
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
32
32
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
33
33
  xax/task/mixins/checkpointing.py,sha256=a6tVyISsDIz68rrhb1rAh3rjQlqkDVJCmSBmETQrnRM,8480
34
- xax/task/mixins/compile.py,sha256=9pVJEUvizu6-6tq0HaMtHGNSi9Yk_mxNyqBFimcfwL0,3683
34
+ xax/task/mixins/compile.py,sha256=8jEdlGs-a14N_CwZA3Rxe461MT83dyIDr3Z56VkjviQ,3693
35
35
  xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
36
36
  xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
37
37
  xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
@@ -43,7 +43,7 @@ xax/task/mixins/train.py,sha256=vsH_QpyrThlh9AzWnyvDJv58Y8U_516oi8gmMq_0iMg,2233
43
43
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
44
44
  xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
45
45
  xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
46
- xax/utils/jax.py,sha256=eObvWt2DraCs2IMDZSdQ0rRk8tA3P5XBlF_UeVq7Aro,5480
46
+ xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
47
47
  xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
48
48
  xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
49
49
  xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
@@ -53,8 +53,11 @@ xax/utils/tensorboard.py,sha256=21czW8WC2SAmwEhz6RLJc_q5HFvNKM4iR1ZycSO5qPE,1705
53
53
  xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
54
54
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
55
55
  xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
56
- xax-0.1.6.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
57
- xax-0.1.6.dist-info/METADATA,sha256=vKxhuOt02ALjFV9fAt-rPVTwvqX4uNr_shL1DGEotA4,1877
58
- xax-0.1.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
59
- xax-0.1.6.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
60
- xax-0.1.6.dist-info/RECORD,,
56
+ xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
57
+ xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
58
+ xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
59
+ xax-0.1.8.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
60
+ xax-0.1.8.dist-info/METADATA,sha256=wnBSNRByXJzgQPuZqNWooidfFdqcT4w8gbwlBgzbJk8,1877
61
+ xax-0.1.8.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
62
+ xax-0.1.8.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
63
+ xax-0.1.8.dist-info/RECORD,,
File without changes