xax 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +11 -18
- xax/nn/equinox.py +2 -2
- xax/nn/export.py +5 -5
- xax/task/mixins/compile.py +1 -1
- xax/utils/jax.py +0 -22
- xax/utils/types/__init__.py +0 -0
- xax/utils/types/frozen_dict.py +148 -0
- xax/utils/types/hashable_array.py +31 -0
- {xax-0.1.6.dist-info → xax-0.1.7.dist-info}/METADATA +1 -1
- {xax-0.1.6.dist-info → xax-0.1.7.dist-info}/RECORD +13 -10
- {xax-0.1.6.dist-info → xax-0.1.7.dist-info}/WHEEL +0 -0
- {xax-0.1.6.dist-info → xax-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.6.dist-info → xax-0.1.7.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.1.
|
15
|
+
__version__ = "0.1.7"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -41,9 +41,6 @@ __all__ = [
|
|
41
41
|
"load_eqx_mlp",
|
42
42
|
"make_eqx_mlp",
|
43
43
|
"save_eqx",
|
44
|
-
"export",
|
45
|
-
"export_flax",
|
46
|
-
"export_with_params",
|
47
44
|
"euler_to_quat",
|
48
45
|
"get_projected_gravity_vector_from_quat",
|
49
46
|
"quat_to_euler",
|
@@ -97,8 +94,6 @@ __all__ = [
|
|
97
94
|
"save_config",
|
98
95
|
"stage_environment",
|
99
96
|
"to_markdown_table",
|
100
|
-
"HashableArray",
|
101
|
-
"hashable_array",
|
102
97
|
"jit",
|
103
98
|
"save_jaxpr_dot",
|
104
99
|
"ColoredFormatter",
|
@@ -132,6 +127,9 @@ __all__ = [
|
|
132
127
|
"snakecase_to_camelcase",
|
133
128
|
"uncolored",
|
134
129
|
"wrapped",
|
130
|
+
"FrozenDict",
|
131
|
+
"HashableArray",
|
132
|
+
"hashable_array",
|
135
133
|
]
|
136
134
|
|
137
135
|
__all__ += [
|
@@ -161,7 +159,7 @@ if "XLA_FLAGS" in os.environ:
|
|
161
159
|
# If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), disable
|
162
160
|
# Triton GEMM kernels. See https://github.com/NVIDIA/JAX-Toolbox
|
163
161
|
if shutil.which("nvidia-smi") is not None:
|
164
|
-
xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler", "--xla_gpu_enable_triton_gemm"]
|
162
|
+
xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler=true", "--xla_gpu_enable_triton_gemm=false"]
|
165
163
|
os.environ["XLA_FLAGS"] = " ".join(xla_flags)
|
166
164
|
|
167
165
|
# If this flag is set, eagerly imports the entire package (not recommended).
|
@@ -197,9 +195,6 @@ NAME_MAP: dict[str, str] = {
|
|
197
195
|
"load_eqx_mlp": "nn.equinox",
|
198
196
|
"make_eqx_mlp": "nn.equinox",
|
199
197
|
"save_eqx": "nn.equinox",
|
200
|
-
"export": "nn.export",
|
201
|
-
"export_flax": "nn.export",
|
202
|
-
"export_with_params": "nn.export",
|
203
198
|
"euler_to_quat": "nn.geom",
|
204
199
|
"get_projected_gravity_vector_from_quat": "nn.geom",
|
205
200
|
"quat_to_euler": "nn.geom",
|
@@ -253,8 +248,6 @@ NAME_MAP: dict[str, str] = {
|
|
253
248
|
"save_config": "utils.experiments",
|
254
249
|
"stage_environment": "utils.experiments",
|
255
250
|
"to_markdown_table": "utils.experiments",
|
256
|
-
"HashableArray": "utils.jax",
|
257
|
-
"hashable_array": "utils.jax",
|
258
251
|
"jit": "utils.jax",
|
259
252
|
"save_jaxpr_dot": "utils.jaxpr",
|
260
253
|
"ColoredFormatter": "utils.logging",
|
@@ -288,6 +281,9 @@ NAME_MAP: dict[str, str] = {
|
|
288
281
|
"snakecase_to_camelcase": "utils.text",
|
289
282
|
"uncolored": "utils.text",
|
290
283
|
"wrapped": "utils.text",
|
284
|
+
"FrozenDict": "utils.types.frozen_dict",
|
285
|
+
"HashableArray": "utils.types.hashable_array",
|
286
|
+
"hashable_array": "utils.types.hashable_array",
|
291
287
|
}
|
292
288
|
|
293
289
|
# Need to manually set some values which can't be auto-generated.
|
@@ -352,11 +348,6 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
352
348
|
make_eqx_mlp,
|
353
349
|
save_eqx,
|
354
350
|
)
|
355
|
-
from xax.nn.export import (
|
356
|
-
export,
|
357
|
-
export_flax,
|
358
|
-
export_with_params,
|
359
|
-
)
|
360
351
|
from xax.nn.geom import (
|
361
352
|
euler_to_quat,
|
362
353
|
get_projected_gravity_vector_from_quat,
|
@@ -408,7 +399,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
408
399
|
stage_environment,
|
409
400
|
to_markdown_table,
|
410
401
|
)
|
411
|
-
from xax.utils.jax import
|
402
|
+
from xax.utils.jax import jit
|
412
403
|
from xax.utils.jaxpr import save_jaxpr_dot
|
413
404
|
from xax.utils.logging import (
|
414
405
|
LOG_ERROR_SUMMARY,
|
@@ -448,5 +439,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
448
439
|
uncolored,
|
449
440
|
wrapped,
|
450
441
|
)
|
442
|
+
from xax.utils.types.frozen_dict import FrozenDict
|
443
|
+
from xax.utils.types.hashable_array import HashableArray, hashable_array
|
451
444
|
|
452
445
|
del TYPE_CHECKING, IMPORT_ALL
|
xax/nn/equinox.py
CHANGED
@@ -72,7 +72,7 @@ def _infer_activation(activation: ActivationFunction) -> Callable:
|
|
72
72
|
raise ValueError(f"Activation function `{activation}` not found in `jax.nn`")
|
73
73
|
|
74
74
|
|
75
|
-
def make_eqx_mlp(hyperparams: MLPHyperParams, key: PRNGKeyArray
|
75
|
+
def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.MLP:
|
76
76
|
"""Create an Equinox MLP from a set of hyperparameters.
|
77
77
|
|
78
78
|
Args:
|
@@ -176,5 +176,5 @@ def load_eqx_mlp(
|
|
176
176
|
) -> eqx.nn.MLP:
|
177
177
|
with open(eqx_file, "rb") as f:
|
178
178
|
hyperparams = json.loads(f.readline().decode(encoding="utf-8"))
|
179
|
-
model = make_eqx_mlp(hyperparams=hyperparams)
|
179
|
+
model = make_eqx_mlp(hyperparams=hyperparams, key=jax.random.PRNGKey(0))
|
180
180
|
return eqx.tree_deserialise_leaves(f, model)
|
xax/nn/export.py
CHANGED
@@ -4,18 +4,18 @@ import logging
|
|
4
4
|
from pathlib import Path
|
5
5
|
from typing import Callable
|
6
6
|
|
7
|
-
import flax
|
8
7
|
import jax
|
9
|
-
import tensorflow as tf
|
10
|
-
from jax.experimental import jax2tf
|
11
8
|
from jaxtyping import Array, PyTree
|
12
9
|
|
13
10
|
try:
|
11
|
+
import flax
|
12
|
+
import tensorflow as tf
|
13
|
+
from jax.experimental import jax2tf
|
14
14
|
from orbax.export import ExportManager, JaxModule, ServingConfig
|
15
15
|
except ImportError as e:
|
16
16
|
raise ImportError(
|
17
|
-
"
|
18
|
-
"'xax[export]` to install the required dependencies."
|
17
|
+
"In order to export models, please install Xax with export dependencies, "
|
18
|
+
"using 'xax[export]` to install the required dependencies."
|
19
19
|
) from e
|
20
20
|
|
21
21
|
logger = logging.getLogger(__name__)
|
xax/task/mixins/compile.py
CHANGED
@@ -22,7 +22,7 @@ def get_cache_dir() -> str | None:
|
|
22
22
|
# By default, only cache on MacOS, since Jax caching on Linux is very
|
23
23
|
# prone to NaNs.
|
24
24
|
match sys.platform:
|
25
|
-
case "darwin":
|
25
|
+
case "darwin" | "linux":
|
26
26
|
return str((Path.home() / ".cache" / "jax" / "jaxcache").resolve())
|
27
27
|
case _:
|
28
28
|
return None
|
xax/utils/jax.py
CHANGED
@@ -138,25 +138,3 @@ def jit(
|
|
138
138
|
return wrapped
|
139
139
|
|
140
140
|
return decorator
|
141
|
-
|
142
|
-
|
143
|
-
class HashableArray:
|
144
|
-
def __init__(self, array: np.ndarray | jnp.ndarray) -> None:
|
145
|
-
if not isinstance(array, (np.ndarray, jnp.ndarray)):
|
146
|
-
raise ValueError(f"Expected np.ndarray or jnp.ndarray, got {type(array)}")
|
147
|
-
self.array = array
|
148
|
-
self._hash: int | None = None
|
149
|
-
|
150
|
-
def __hash__(self) -> int:
|
151
|
-
if self._hash is None:
|
152
|
-
self._hash = hash(self.array.tobytes())
|
153
|
-
return self._hash
|
154
|
-
|
155
|
-
def __eq__(self, other: object) -> bool:
|
156
|
-
if not isinstance(other, HashableArray):
|
157
|
-
return False
|
158
|
-
return bool(jnp.array_equal(self.array, other.array))
|
159
|
-
|
160
|
-
|
161
|
-
def hashable_array(array: np.ndarray | jnp.ndarray) -> HashableArray:
|
162
|
-
return HashableArray(array)
|
File without changes
|
@@ -0,0 +1,148 @@
|
|
1
|
+
"""Defines a frozen dictionary type.
|
2
|
+
|
3
|
+
This is mostly taken from Flax - we move it here to avoid having to use Flax as
|
4
|
+
a dependency in downstream projects.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import collections
|
8
|
+
from types import MappingProxyType
|
9
|
+
from typing import Any, Iterator, Mapping, Self, TypeVar
|
10
|
+
|
11
|
+
import jax
|
12
|
+
|
13
|
+
K = TypeVar("K")
|
14
|
+
V = TypeVar("V")
|
15
|
+
|
16
|
+
|
17
|
+
def _prepare_freeze(xs: Any) -> Any: # noqa: ANN401
|
18
|
+
"""Deep copy unfrozen dicts to make the dictionary FrozenDict safe."""
|
19
|
+
if isinstance(xs, FrozenDict):
|
20
|
+
return xs._dict # pylint: disable=protected-access
|
21
|
+
if not isinstance(xs, dict):
|
22
|
+
return xs
|
23
|
+
return {key: _prepare_freeze(val) for key, val in xs.items()}
|
24
|
+
|
25
|
+
|
26
|
+
def _indent(x: str, num_spaces: int) -> str:
|
27
|
+
indent_str = " " * num_spaces
|
28
|
+
lines = x.split("\n")
|
29
|
+
assert not lines[-1]
|
30
|
+
return "\n".join(indent_str + line for line in lines[:-1]) + "\n"
|
31
|
+
|
32
|
+
|
33
|
+
class FrozenKeysView(collections.abc.KeysView[K]):
|
34
|
+
def __repr__(self) -> str:
|
35
|
+
return f"frozen_dict_keys({list(self)})"
|
36
|
+
|
37
|
+
|
38
|
+
class FrozenValuesView(collections.abc.ValuesView[V]):
|
39
|
+
def __repr__(self) -> str:
|
40
|
+
return f"frozen_dict_values({list(self)})"
|
41
|
+
|
42
|
+
|
43
|
+
@jax.tree_util.register_pytree_with_keys_class
|
44
|
+
class FrozenDict(Mapping[K, V]):
|
45
|
+
"""An immutable variant of the Python dict."""
|
46
|
+
|
47
|
+
__slots__ = ("_dict", "_hash")
|
48
|
+
|
49
|
+
def __init__(self, *args: Any, __unsafe_skip_copy__: bool = False, **kwargs: Any) -> None: # noqa: ANN401
|
50
|
+
# make sure the dict is as
|
51
|
+
xs = dict(*args, **kwargs)
|
52
|
+
if __unsafe_skip_copy__:
|
53
|
+
self._dict = xs
|
54
|
+
else:
|
55
|
+
self._dict = _prepare_freeze(xs)
|
56
|
+
|
57
|
+
self._hash: int | None = None
|
58
|
+
|
59
|
+
def __getitem__(self, key: K) -> V:
|
60
|
+
v = self._dict[key]
|
61
|
+
if isinstance(v, dict):
|
62
|
+
return FrozenDict(v) # type: ignore[return-value]
|
63
|
+
return v
|
64
|
+
|
65
|
+
def __setitem__(self, key: K, value: V) -> None:
|
66
|
+
raise ValueError("FrozenDict is immutable.")
|
67
|
+
|
68
|
+
def __contains__(self, key: object) -> bool:
|
69
|
+
return key in self._dict
|
70
|
+
|
71
|
+
def __iter__(self) -> Iterator[K]:
|
72
|
+
return iter(self._dict)
|
73
|
+
|
74
|
+
def __len__(self) -> int:
|
75
|
+
return len(self._dict)
|
76
|
+
|
77
|
+
def __repr__(self) -> str:
|
78
|
+
return self.pretty_repr()
|
79
|
+
|
80
|
+
def __reduce__(self) -> tuple[type["FrozenDict[K, V]"], tuple[dict[K, V]]]:
|
81
|
+
return FrozenDict, (self.unfreeze(),)
|
82
|
+
|
83
|
+
def pretty_repr(self, num_spaces: int = 4) -> str:
|
84
|
+
"""Returns an indented representation of the nested dictionary."""
|
85
|
+
|
86
|
+
def pretty_dict(x: Any) -> str: # noqa: ANN401
|
87
|
+
if not isinstance(x, dict):
|
88
|
+
return repr(x)
|
89
|
+
rep = ""
|
90
|
+
for key, val in x.items():
|
91
|
+
rep += f"{key}: {pretty_dict(val)},\n"
|
92
|
+
if rep:
|
93
|
+
return "{\n" + _indent(rep, num_spaces) + "}"
|
94
|
+
else:
|
95
|
+
return "{}"
|
96
|
+
|
97
|
+
return f"FrozenDict({pretty_dict(self._dict)})"
|
98
|
+
|
99
|
+
def __hash__(self) -> int:
|
100
|
+
if self._hash is None:
|
101
|
+
h = 0
|
102
|
+
for key, value in self.items():
|
103
|
+
h ^= hash((key, value))
|
104
|
+
self._hash = h
|
105
|
+
return self._hash
|
106
|
+
|
107
|
+
def copy(self, add_or_replace: Mapping[K, V] = MappingProxyType({})) -> Self:
|
108
|
+
return type(self)({**self, **unfreeze(add_or_replace)}) # type: ignore[arg-type]
|
109
|
+
|
110
|
+
def keys(self) -> FrozenKeysView[K]:
|
111
|
+
return FrozenKeysView(self)
|
112
|
+
|
113
|
+
def values(self) -> FrozenValuesView[V]:
|
114
|
+
return FrozenValuesView(self)
|
115
|
+
|
116
|
+
def items(self) -> Iterator[tuple[K, V]]: # type: ignore[override]
|
117
|
+
for key in self._dict:
|
118
|
+
yield (key, self[key])
|
119
|
+
|
120
|
+
def pop(self, key: K) -> tuple["FrozenDict[K, V]", V]:
|
121
|
+
value = self[key]
|
122
|
+
new_dict = dict(self._dict)
|
123
|
+
new_dict.pop(key)
|
124
|
+
new_self = type(self)(new_dict)
|
125
|
+
return new_self, value
|
126
|
+
|
127
|
+
def unfreeze(self) -> dict[K, V]:
|
128
|
+
return unfreeze(self)
|
129
|
+
|
130
|
+
def tree_flatten_with_keys(self) -> tuple[tuple[tuple[jax.tree_util.DictKey, Any], ...], tuple[K, ...]]:
|
131
|
+
sorted_keys = sorted(self._dict)
|
132
|
+
return tuple([(jax.tree_util.DictKey(k), self._dict[k]) for k in sorted_keys]), tuple(sorted_keys)
|
133
|
+
|
134
|
+
@classmethod
|
135
|
+
def tree_unflatten(cls, keys: tuple[K, ...], values: tuple[Any, ...]) -> "FrozenDict[K, V]":
|
136
|
+
return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True)
|
137
|
+
|
138
|
+
|
139
|
+
def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]: # noqa: ANN401
|
140
|
+
if isinstance(x, FrozenDict):
|
141
|
+
return jax.tree_util.tree_map(lambda y: y, x._dict)
|
142
|
+
elif isinstance(x, dict):
|
143
|
+
ys = {}
|
144
|
+
for key, value in x.items():
|
145
|
+
ys[key] = unfreeze(value)
|
146
|
+
return ys
|
147
|
+
else:
|
148
|
+
return x
|
@@ -0,0 +1,31 @@
|
|
1
|
+
"""Defines a hashable array wrapper.
|
2
|
+
|
3
|
+
Since Jax relies extensively on hashing, and we sometimes want to treat Jax
|
4
|
+
arrays as constants, this wrapper lets us ensure that Jax and Numpy arrays can
|
5
|
+
be hashed for Jitting.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import jax.numpy as jnp
|
9
|
+
import numpy as np
|
10
|
+
|
11
|
+
|
12
|
+
class HashableArray:
|
13
|
+
def __init__(self, array: np.ndarray | jnp.ndarray) -> None:
|
14
|
+
if not isinstance(array, (np.ndarray, jnp.ndarray)):
|
15
|
+
raise ValueError(f"Expected np.ndarray or jnp.ndarray, got {type(array)}")
|
16
|
+
self.array = array
|
17
|
+
self._hash: int | None = None
|
18
|
+
|
19
|
+
def __hash__(self) -> int:
|
20
|
+
if self._hash is None:
|
21
|
+
self._hash = hash(self.array.tobytes())
|
22
|
+
return self._hash
|
23
|
+
|
24
|
+
def __eq__(self, other: object) -> bool:
|
25
|
+
if not isinstance(other, HashableArray):
|
26
|
+
return False
|
27
|
+
return bool(jnp.array_equal(self.array, other.array))
|
28
|
+
|
29
|
+
|
30
|
+
def hashable_array(array: np.ndarray | jnp.ndarray) -> HashableArray:
|
31
|
+
return HashableArray(array)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=Yxbt78w3k1Y827EjR_kiT6_01FyRrJKn1U5WkC5I2Ik,13453
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
|
@@ -7,8 +7,8 @@ xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
|
|
7
7
|
xax/core/state.py,sha256=y123fL7pMgk25TPG6KN0LRIF_eYnD9eP7OfqtoQJGNE,2178
|
8
8
|
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
9
|
xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
10
|
-
xax/nn/equinox.py,sha256=
|
11
|
-
xax/nn/export.py,sha256=
|
10
|
+
xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
|
11
|
+
xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
|
12
12
|
xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
|
13
13
|
xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
|
14
14
|
xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
|
@@ -31,7 +31,7 @@ xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCn
|
|
31
31
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
32
32
|
xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
|
33
33
|
xax/task/mixins/checkpointing.py,sha256=a6tVyISsDIz68rrhb1rAh3rjQlqkDVJCmSBmETQrnRM,8480
|
34
|
-
xax/task/mixins/compile.py,sha256=
|
34
|
+
xax/task/mixins/compile.py,sha256=8jEdlGs-a14N_CwZA3Rxe461MT83dyIDr3Z56VkjviQ,3693
|
35
35
|
xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
|
36
36
|
xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
|
37
37
|
xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
|
@@ -43,7 +43,7 @@ xax/task/mixins/train.py,sha256=vsH_QpyrThlh9AzWnyvDJv58Y8U_516oi8gmMq_0iMg,2233
|
|
43
43
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
44
44
|
xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
|
45
45
|
xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
|
46
|
-
xax/utils/jax.py,sha256=
|
46
|
+
xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
|
47
47
|
xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
|
48
48
|
xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
|
49
49
|
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
@@ -53,8 +53,11 @@ xax/utils/tensorboard.py,sha256=21czW8WC2SAmwEhz6RLJc_q5HFvNKM4iR1ZycSO5qPE,1705
|
|
53
53
|
xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
|
54
54
|
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
55
55
|
xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
|
56
|
-
xax
|
57
|
-
xax
|
58
|
-
xax
|
59
|
-
xax-0.1.
|
60
|
-
xax-0.1.
|
56
|
+
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
57
|
+
xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
|
58
|
+
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
59
|
+
xax-0.1.7.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
60
|
+
xax-0.1.7.dist-info/METADATA,sha256=KWlXilX8eGmFEi9-FR-QyX3pyUEEbT2dCjx8tFA1yKU,1877
|
61
|
+
xax-0.1.7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
62
|
+
xax-0.1.7.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
63
|
+
xax-0.1.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|