xax 0.0.7__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +94 -4
- xax/nn/equinox.py +180 -0
- xax/nn/export.py +147 -0
- xax/nn/geom.py +26 -0
- xax/nn/norm.py +23 -0
- xax/requirements.txt +1 -0
- xax/task/base.py +6 -0
- xax/task/logger.py +97 -2
- xax/task/loggers/stdout.py +2 -2
- xax/task/loggers/tensorboard.py +25 -14
- xax/task/mixins/artifacts.py +1 -21
- xax/task/mixins/checkpointing.py +19 -5
- xax/task/mixins/logger.py +28 -4
- xax/task/mixins/step_wrapper.py +23 -32
- xax/task/mixins/train.py +50 -34
- xax/task/script.py +0 -4
- xax/utils/debugging.py +49 -0
- xax/utils/experiments.py +23 -4
- xax/utils/jaxpr.py +77 -0
- xax/utils/pytree.py +189 -1
- xax/utils/tensorboard.py +177 -1
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/METADATA +23 -4
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/RECORD +26 -21
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/WHEEL +1 -1
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {xax-0.0.7.dist-info → xax-0.1.0.dist-info}/top_level.txt +0 -0
xax/task/script.py
CHANGED
@@ -17,8 +17,6 @@ from xax.task.mixins import (
|
|
17
17
|
ProcessMixin,
|
18
18
|
RunnableConfig,
|
19
19
|
RunnableMixin,
|
20
|
-
StepContextConfig,
|
21
|
-
StepContextMixin,
|
22
20
|
)
|
23
21
|
|
24
22
|
|
@@ -28,7 +26,6 @@ class ScriptConfig(
|
|
28
26
|
GPUStatsConfig,
|
29
27
|
ProcessConfig,
|
30
28
|
LoggerConfig,
|
31
|
-
StepContextConfig,
|
32
29
|
ArtifactsConfig,
|
33
30
|
RunnableConfig,
|
34
31
|
BaseConfig,
|
@@ -44,7 +41,6 @@ class Script(
|
|
44
41
|
GPUStatsMixin[ConfigT],
|
45
42
|
ProcessMixin[ConfigT],
|
46
43
|
LoggerMixin[ConfigT],
|
47
|
-
StepContextMixin[ConfigT],
|
48
44
|
ArtifactsMixin[ConfigT],
|
49
45
|
RunnableMixin[ConfigT],
|
50
46
|
BaseTask[ConfigT],
|
xax/utils/debugging.py
ADDED
@@ -0,0 +1,49 @@
|
|
1
|
+
"""Defines some useful Jax debugging utilities."""
|
2
|
+
|
3
|
+
from collections import deque
|
4
|
+
from collections.abc import Iterable, Mapping
|
5
|
+
from typing import Any, Callable, Deque
|
6
|
+
|
7
|
+
from jaxtyping import Array
|
8
|
+
|
9
|
+
|
10
|
+
def get_named_leaves(
    obj: Any,  # noqa: ANN401
    is_leaf: Callable[[Any], bool] = lambda x: isinstance(x, Array),  # noqa: ANN401
    max_depth: int = 100,
) -> list[tuple[str, Any]]:  # noqa: ANN401
    """Collects the leaves of a nested object together with dotted path names.

    The object graph is walked breadth-first. Mappings are traversed by key,
    objects with a ``__dict__`` by attribute name, and any other iterable by
    position. Strings and bytes are treated as opaque values and are never
    traversed.

    Args:
        obj: The root object to walk. The root itself is never reported.
        is_leaf: Predicate selecting which children are reported as leaves;
            children that are not leaves are walked further. Defaults to
            matching JAX arrays.
        max_depth: Nodes deeper than this are skipped.

    Returns:
        ``(name, leaf)`` pairs in breadth-first order, where ``name`` joins the
        path components (keys, attribute names or indices) with ``"."``.
    """
    ret: list[tuple[str, Any]] = []
    q: Deque[tuple[int, str, Any]] = deque()  # noqa: ANN401
    q.append((0, "", obj))

    while q:
        depth, name, node = q.popleft()
        if depth > max_depth:
            continue

        children: Iterable[tuple[Any, Any]]
        if isinstance(node, Mapping):
            # Checked before ``__dict__``: dict subclasses also carry an
            # (usually empty) instance ``__dict__`` that would hide their items.
            children = node.items()
        elif hasattr(node, "__dict__") and isinstance(node.__dict__, Mapping):
            children = node.__dict__.items()
        elif isinstance(node, (str, bytes, bytearray)):
            # Iterating text yields 1-character strings, which are themselves
            # iterable; treat text as an opaque value instead.
            continue
        elif isinstance(node, Iterable):
            children = enumerate(node)
        else:
            continue

        for cname, cnode in children:
            gname = f"{name}.{cname}" if name else str(cname)
            if is_leaf(cnode):
                ret.append((gname, cnode))
            else:
                q.append((depth + 1, gname, cnode))

    return ret
|
xax/utils/experiments.py
CHANGED
@@ -23,7 +23,8 @@ import urllib.request
|
|
23
23
|
import warnings
|
24
24
|
from abc import ABC, abstractmethod
|
25
25
|
from pathlib import Path
|
26
|
-
from
|
26
|
+
from types import TracebackType
|
27
|
+
from typing import Any, Iterator, Self, TypeVar, cast
|
27
28
|
from urllib.parse import urlparse
|
28
29
|
|
29
30
|
import git
|
@@ -116,19 +117,19 @@ class StateTimer:
|
|
116
117
|
logs: dict[str, dict[str, int | float]] = {}
|
117
118
|
|
118
119
|
# Logs step statistics.
|
119
|
-
logs["
|
120
|
+
logs["⌛ steps"] = {
|
120
121
|
"total": self.step_timer.steps,
|
121
122
|
"per-second": self.step_timer.steps_per_second,
|
122
123
|
}
|
123
124
|
|
124
125
|
# Logs sample statistics.
|
125
|
-
logs["
|
126
|
+
logs["⌛ samples"] = {
|
126
127
|
"total": self.sample_timer.steps,
|
127
128
|
"per-second": self.sample_timer.steps_per_second,
|
128
129
|
}
|
129
130
|
|
130
131
|
# Logs full iteration statistics.
|
131
|
-
logs["
|
132
|
+
logs["⌛ dt"] = {
|
132
133
|
"iter": self.iter_timer.iter_seconds,
|
133
134
|
}
|
134
135
|
|
@@ -147,6 +148,24 @@ class IntervalTicker:
|
|
147
148
|
return False
|
148
149
|
|
149
150
|
|
151
|
+
class ContextTimer:
|
152
|
+
def __init__(self) -> None:
|
153
|
+
self.start_time = 0.0
|
154
|
+
self.elapsed_time = 0.0
|
155
|
+
|
156
|
+
def __enter__(self) -> Self:
|
157
|
+
self.start_time = time.time()
|
158
|
+
return self
|
159
|
+
|
160
|
+
def __exit__(
|
161
|
+
self,
|
162
|
+
exc_type: type[BaseException] | None,
|
163
|
+
exc_value: BaseException | None,
|
164
|
+
traceback: TracebackType | None,
|
165
|
+
) -> None:
|
166
|
+
self.elapsed_time = time.time() - self.start_time
|
167
|
+
|
168
|
+
|
150
169
|
def abs_path(path: str) -> str:
    """Return ``path`` as an absolute path string with symlinks resolved."""
    resolved = Path(path).resolve()
    return str(resolved)
|
152
171
|
|
xax/utils/jaxpr.py
ADDED
@@ -0,0 +1,77 @@
|
|
1
|
+
"""Visualize JAXPR."""
|
2
|
+
|
3
|
+
from pathlib import Path
|
4
|
+
|
5
|
+
import jax
|
6
|
+
import jax.core
|
7
|
+
|
8
|
+
|
9
|
+
def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) -> None:
|
10
|
+
"""Save the JAXPR to a DOT file.
|
11
|
+
|
12
|
+
Example usage:
|
13
|
+
|
14
|
+
grad_fn_jaxpr = jax.make_jaxpr(loss_fn)(variables)
|
15
|
+
save_jaxpr_dot(grad_fn_jaxpr, "grad_fn_jaxpr.dot")
|
16
|
+
|
17
|
+
Then, you can visualize the JAXPR using Graphviz:
|
18
|
+
|
19
|
+
dot -Tpng grad_fn_jaxpr.dot > grad_fn_jaxpr.png
|
20
|
+
|
21
|
+
Args:
|
22
|
+
closed_jaxpr: The closed JAXPR to save.
|
23
|
+
filename: The filename to save the JAXPR to.
|
24
|
+
"""
|
25
|
+
if hasattr(closed_jaxpr, "jaxpr"):
|
26
|
+
jaxpr = closed_jaxpr.jaxpr
|
27
|
+
else:
|
28
|
+
jaxpr = closed_jaxpr
|
29
|
+
|
30
|
+
with open(filename, "w") as f:
|
31
|
+
f.write("digraph Jaxpr {\n")
|
32
|
+
|
33
|
+
var_names: dict[jax.core.Var, str] = {}
|
34
|
+
var_count = 0
|
35
|
+
|
36
|
+
def get_var_name(var: jax.core.Var) -> str:
|
37
|
+
"""Get a unique name for a variable."""
|
38
|
+
nonlocal var_names, var_count
|
39
|
+
|
40
|
+
# Handle Literal objects specially since they're not hashable
|
41
|
+
if isinstance(var, jax.core.Literal):
|
42
|
+
# Create a name based on the literal value
|
43
|
+
name = f"lit_{var.val}"
|
44
|
+
return name
|
45
|
+
|
46
|
+
# For other variables
|
47
|
+
if var not in var_names:
|
48
|
+
name = f"var_{var_count}"
|
49
|
+
var_names[var] = name
|
50
|
+
var_count += 1
|
51
|
+
return var_names[var]
|
52
|
+
|
53
|
+
for var in jaxpr.invars:
|
54
|
+
node_name = get_var_name(var)
|
55
|
+
f.write(f' {node_name} [label="{node_name}\\n(input)"];\n')
|
56
|
+
|
57
|
+
eq_count = 0
|
58
|
+
for eq in jaxpr.eqns:
|
59
|
+
eq_node = f"eq{eq_count}"
|
60
|
+
label = f"{eq.primitive.name}"
|
61
|
+
f.write(f' {eq_node} [shape=box, label="{label}"];\n')
|
62
|
+
|
63
|
+
for invar in eq.invars:
|
64
|
+
var_name = get_var_name(invar)
|
65
|
+
f.write(f" {var_name} -> {eq_node};\n")
|
66
|
+
|
67
|
+
for outvar in eq.outvars:
|
68
|
+
var_name = get_var_name(outvar)
|
69
|
+
f.write(f" {eq_node} -> {var_name};\n")
|
70
|
+
|
71
|
+
eq_count += 1
|
72
|
+
|
73
|
+
for var in jaxpr.outvars:
|
74
|
+
node_name = get_var_name(var)
|
75
|
+
f.write(f' {node_name} [peripheries=2, label="{node_name}\\n(output)"];\n')
|
76
|
+
|
77
|
+
f.write("}\n")
|
xax/utils/pytree.py
CHANGED
@@ -4,7 +4,7 @@ import chex
|
|
4
4
|
import jax
|
5
5
|
import jax.numpy as jnp
|
6
6
|
from jax import Array
|
7
|
-
from jaxtyping import PyTree
|
7
|
+
from jaxtyping import PRNGKeyArray, PyTree
|
8
8
|
|
9
9
|
|
10
10
|
def slice_array(x: Array, start: Array, slice_length: int) -> Array:
|
@@ -12,6 +12,14 @@ def slice_array(x: Array, start: Array, slice_length: int) -> Array:
|
|
12
12
|
|
13
13
|
For multi-dimensional arrays, this slices only along the first dimension
|
14
14
|
and keeps all other dimensions intact.
|
15
|
+
|
16
|
+
Args:
|
17
|
+
x: The array to slice.
|
18
|
+
start: The start index of the slice.
|
19
|
+
slice_length: The length of the slice.
|
20
|
+
|
21
|
+
Returns:
|
22
|
+
The sliced array.
|
15
23
|
"""
|
16
24
|
chex.assert_shape(start, ())
|
17
25
|
chex.assert_shape(slice_length, ())
|
@@ -38,6 +46,21 @@ def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
|
|
38
46
|
return jax.tree_util.tree_map(lambda x: flatten_array(x, flatten_size), pytree)
|
39
47
|
|
40
48
|
|
49
|
+
def pytree_has_nans(pytree: PyTree) -> Array:
    """Check if a pytree has any NaNs.

    Args:
        pytree: A pytree whose leaves are arrays (or values accepted by
            ``jnp.isnan``).

    Returns:
        A boolean scalar array that is True if any leaf contains a NaN. An empty
        pytree yields False.
    """
    leaf_has_nans = jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree)
    # The explicit initializer keeps empty pytrees (e.g. ``{}``) from raising
    # "reduce() of empty iterable with no initial value".
    return jax.tree_util.tree_reduce(jnp.logical_or, leaf_has_nans, jnp.array(False))
|
56
|
+
|
57
|
+
|
58
|
+
def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
    """Select, leaf by leaf, from ``new`` where ``cond`` holds and from ``original`` elsewhere.

    ``jnp.where`` only understands arrays, not whole pytrees, so the selection
    is applied to each pair of corresponding leaves. ``cond`` must broadcast
    against every leaf.
    """

    def select(new_leaf: Array, original_leaf: Array) -> Array:
        return jnp.where(cond, new_leaf, original_leaf)

    return jax.tree_util.tree_map(select, new, original)
|
62
|
+
|
63
|
+
|
41
64
|
def compute_nan_ratio(pytree: PyTree) -> Array:
|
42
65
|
"""Computes the ratio of NaNs vs non-NaNs in a given PyTree."""
|
43
66
|
nan_counts = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
|
@@ -48,3 +71,168 @@ def compute_nan_ratio(pytree: PyTree) -> Array:
|
|
48
71
|
overall_nan_ratio = jnp.array(total_nans / total_elements)
|
49
72
|
|
50
73
|
return overall_nan_ratio
|
74
|
+
|
75
|
+
|
76
|
+
def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
    """Shuffle every matching leaf of a pytree with one shared random permutation.

    For each leaf, the leading ``len(batch_shape)`` axes are flattened into a
    single axis. That axis is permuted with the same permutation for all leaves
    (so corresponding entries stay aligned), and the leaf is reshaped back.

    Leaves are returned untouched when they are not JAX arrays (this includes
    NumPy arrays) or when their leading dimensions differ from ``batch_shape``.

    Args:
        data: A pytree with arrays.
        batch_shape: Sizes of the leading dimensions to shuffle.
        rng: A JAX PRNG key.

    Returns:
        A pytree with the same structure, with matching leaves reshuffled.
    """
    num_batch_dims = len(batch_shape)
    total = 1
    for size in batch_shape:
        total *= size
    order = jax.random.permutation(rng, total)

    def shuffle_leaf(leaf: Array) -> Array:
        if not isinstance(leaf, jnp.ndarray):
            return leaf
        shape = leaf.shape
        too_few_dims = len(shape) < num_batch_dims
        if too_few_dims or any(shape[axis] != size for axis, size in enumerate(batch_shape)):
            return leaf
        flat = leaf.reshape((total,) + shape[num_batch_dims:])
        return flat[order].reshape(shape)

    return jax.tree_util.tree_map(shuffle_leaf, data)
|
122
|
+
|
123
|
+
|
124
|
+
def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
    """Reshuffle a pytree across its leading batch dimensions, each independently.

    Every leading dimension gets its own random permutation (drawn from a split
    of ``rng``), so ``out[i, j, ...] == x[p0[i], p1[j], ...]``. The same
    permutations are used for every leaf, keeping leaves aligned.

    Leaves that are not JAX arrays, or whose leading dimensions differ from
    ``batch_shape``, are returned unchanged (as in ``reshuffle_pytree``).

    Args:
        data: A pytree with arrays.
        batch_shape: Sizes of the leading dimensions to shuffle.
        rng: A JAX PRNG key.

    Returns:
        A pytree with the same structure, with matching leaves reshuffled.
    """
    rngs = jax.random.split(rng, len(batch_shape))
    perms = [jax.random.permutation(rng_i, dim) for rng_i, dim in zip(rngs, batch_shape)]
    # n-dimensional index grid from the permutations.
    idx_grids = tuple(jnp.meshgrid(*perms, indexing="ij"))
    num_batch_dims = len(batch_shape)

    def permute_array(x: Array) -> Array:
        if not isinstance(x, Array):
            return x
        # Leaves whose leading dims do not match would otherwise fail to index,
        # or be silently truncated when a dimension exceeds the permutation.
        if x.ndim < num_batch_dims or tuple(x.shape[:num_batch_dims]) != tuple(batch_shape):
            return x
        return x[idx_grids]

    return jax.tree_util.tree_map(permute_array, data)
|
137
|
+
|
138
|
+
|
139
|
+
# A leaf prepared for shuffling: (transposed leaf, transpose order, original shape).
TransposeResult = tuple[PyTree, tuple[int, ...], tuple[int, ...]]
# Key path of a leaf as passed by ``jax.tree_util.tree_map_with_path``. Its
# entries are JAX key objects (``DictKey``, ``SequenceKey``, ``GetAttrKey``, ...),
# not plain ``str``/``int`` keys, hence ``object``.
PathType = tuple[object, ...]
|
141
|
+
|
142
|
+
|
143
|
+
def reshuffle_pytree_along_dims(
    data: PyTree,
    dims: tuple[int, ...],
    shape_dims: tuple[int, ...],
    rng: PRNGKeyArray,
) -> PyTree:
    """Reshuffle a pytree along arbitrary dimensions.

    Allows reshuffling along any dimensions, not just the leading ones. For
    each eligible leaf, the specified dimensions are moved to the front
    (keeping the relative order of the others), the leaves are shuffled
    together with ``reshuffle_pytree``, and the dimensions are moved back.

    A leaf is eligible when it is a JAX array that has every dimension in
    ``dims``. All other leaves are returned unchanged.

    Args:
        data: A pytree with arrays.
        dims: A tuple of integers specifying which dimensions to shuffle along.
            For example, (1,) would shuffle along the second dimension.
            Negative values count from the end of each leaf's shape.
        shape_dims: A tuple of integers specifying the size of each dimension to shuffle.
            Must have the same length as dims.
        rng: A JAX PRNG key.

    Returns:
        A new pytree with the same structure but with the data reshuffled along
        the specified dimensions.

    Raises:
        ValueError: If ``dims`` and ``shape_dims`` differ in length, or if an
            eligible leaf's sizes along ``dims`` differ from ``shape_dims``.
    """
    if len(dims) != len(shape_dims):
        raise ValueError(f"dims {dims} and shape_dims {shape_dims} must have the same length")
    if not dims:
        # Shuffling along zero dimensions is the identity.
        return data

    front = tuple(range(len(dims)))

    def is_shufflable(x: PyTree) -> bool:
        """Return True for array leaves that have all of ``dims``; validate their sizes."""
        if not isinstance(x, jnp.ndarray):
            return False
        if not all(-x.ndim <= dim < x.ndim for dim in dims):
            return False
        for dim, size in zip(dims, shape_dims):
            if x.shape[dim] != size:
                raise ValueError(f"Array shape {x.shape} doesn't match shape_dims {shape_dims} at dimension {dim}")
        return True

    leaves, treedef = jax.tree_util.tree_flatten(data)
    indices = [i for i, leaf in enumerate(leaves) if is_shufflable(leaf)]

    # Only eligible leaves are handed to reshuffle_pytree. Otherwise a leaf that
    # lacks the requested dimensions, but whose leading dimensions happen to
    # equal shape_dims, would be shuffled along the wrong axis.
    moved = [jnp.moveaxis(leaves[i], dims, front) for i in indices]
    shuffled = reshuffle_pytree(moved, shape_dims, rng)

    for i, leaf in zip(indices, shuffled):
        leaves[i] = jnp.moveaxis(leaf, front, dims)
    return jax.tree_util.tree_unflatten(treedef, leaves)
|
xax/utils/tensorboard.py
CHANGED
@@ -14,7 +14,7 @@ from PIL.Image import Image as PILImage
|
|
14
14
|
from tensorboard.compat.proto.config_pb2 import RunMetadata
|
15
15
|
from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
|
16
16
|
from tensorboard.compat.proto.graph_pb2 import GraphDef
|
17
|
-
from tensorboard.compat.proto.summary_pb2 import Summary, SummaryMetadata
|
17
|
+
from tensorboard.compat.proto.summary_pb2 import HistogramProto, Summary, SummaryMetadata
|
18
18
|
from tensorboard.compat.proto.tensor_pb2 import TensorProto
|
19
19
|
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
|
20
20
|
from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
|
@@ -25,6 +25,65 @@ from xax.core.state import Phase
|
|
25
25
|
ImageShape = Literal["HWC", "CHW", "HW", "NHWC", "NCHW", "NHW"]
|
26
26
|
|
27
27
|
|
28
|
+
def make_histogram(values: np.ndarray, bins: str | np.ndarray, max_bins: int | None = None) -> HistogramProto:
|
29
|
+
"""Convert values into a histogram proto using logic from histogram.cc.
|
30
|
+
|
31
|
+
Args:
|
32
|
+
values: Input values to create histogram from
|
33
|
+
bins: Bin specification (string like 'auto' or array of bin edges)
|
34
|
+
max_bins: Optional maximum number of bins
|
35
|
+
|
36
|
+
Returns:
|
37
|
+
HistogramProto containing the histogram data
|
38
|
+
"""
|
39
|
+
if values.size == 0:
|
40
|
+
raise ValueError("The input has no element.")
|
41
|
+
values = values.reshape(-1)
|
42
|
+
counts, limits = np.histogram(values, bins=bins)
|
43
|
+
num_bins = len(counts)
|
44
|
+
|
45
|
+
if max_bins is not None and num_bins > max_bins:
|
46
|
+
subsampling = num_bins // max_bins
|
47
|
+
subsampling_remainder = num_bins % subsampling
|
48
|
+
if subsampling_remainder != 0:
|
49
|
+
counts = np.pad(
|
50
|
+
counts,
|
51
|
+
pad_width=[[0, subsampling - subsampling_remainder]],
|
52
|
+
mode="constant",
|
53
|
+
constant_values=0,
|
54
|
+
)
|
55
|
+
counts = counts.reshape(-1, subsampling).sum(axis=-1)
|
56
|
+
new_limits = np.empty((counts.size + 1,), limits.dtype)
|
57
|
+
new_limits[:-1] = limits[:-1:subsampling]
|
58
|
+
new_limits[-1] = limits[-1]
|
59
|
+
limits = new_limits
|
60
|
+
|
61
|
+
# Find the first and the last bin defining the support of the histogram:
|
62
|
+
cum_counts = np.cumsum(np.greater(counts, 0))
|
63
|
+
start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
|
64
|
+
start = int(start)
|
65
|
+
end = int(end) + 1
|
66
|
+
|
67
|
+
# TensorBoard only includes the right bin limits. To still have the leftmost limit
|
68
|
+
# included, we include an empty bin left.
|
69
|
+
counts = counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
|
70
|
+
limits = limits[start : end + 1]
|
71
|
+
|
72
|
+
if counts.size == 0 or limits.size == 0:
|
73
|
+
raise ValueError("The histogram is empty, please file a bug report.")
|
74
|
+
|
75
|
+
sum_sq = values.dot(values)
|
76
|
+
return HistogramProto(
|
77
|
+
min=values.min(),
|
78
|
+
max=values.max(),
|
79
|
+
num=len(values),
|
80
|
+
sum=values.sum(),
|
81
|
+
sum_squares=sum_sq,
|
82
|
+
bucket_limit=limits.tolist(),
|
83
|
+
bucket=counts.tolist(),
|
84
|
+
)
|
85
|
+
|
86
|
+
|
28
87
|
class TensorboardProtobufWriter:
|
29
88
|
def __init__(
|
30
89
|
self,
|
@@ -263,6 +322,123 @@ class TensorboardWriter:
|
|
263
322
|
walltime=walltime,
|
264
323
|
)
|
265
324
|
|
325
|
+
def add_histogram(
    self,
    tag: str,
    values: np.ndarray,
    global_step: int | None = None,
    bins: str | np.ndarray = "auto",
    walltime: float | None = None,
    max_bins: int | None = None,
) -> None:
    """Log a histogram of ``values`` under ``tag``.

    The values are cast to ``float``, binned with ``make_histogram`` and
    written as a histogram summary through ``self.pb_writer``.

    Args:
        tag: Data identifier.
        values: Array of values to bin.
        global_step: Global step value to record.
        bins: Bin specification forwarded to ``make_histogram``.
        walltime: Optional override of the default walltime.
        max_bins: Optional maximum number of bins forwarded to ``make_histogram``.
    """
    as_float = values.astype(float)
    proto = make_histogram(as_float, bins, max_bins)
    summary = Summary(value=[Summary.Value(tag=tag, histo=proto)])
    self.pb_writer.add_summary(summary, global_step=global_step, walltime=walltime)
|
340
|
+
|
341
|
+
def add_histogram_raw(
    self,
    tag: str,
    min: float | np.ndarray,
    max: float | np.ndarray,
    num: int | np.ndarray,
    sum: float | np.ndarray,
    sum_squares: float | np.ndarray,
    bucket_limits: list[float | np.ndarray],
    bucket_counts: list[int | np.ndarray],
    global_step: int | None = None,
    walltime: float | None = None,
) -> None:
    """Log a histogram whose statistics and buckets were computed by the caller.

    Every numeric argument is converted to a plain Python ``float``/``int``
    before it is placed into the ``HistogramProto``.

    Args:
        tag: Data identifier
        min: Min value
        max: Max value
        num: Number of values
        sum: Sum of all values
        sum_squares: Sum of squares for all values
        bucket_limits: Upper value per bucket
        bucket_counts: Number of values per bucket
        global_step: Global step value to record
        walltime: Optional override default walltime

    Raises:
        ValueError: If ``bucket_limits`` and ``bucket_counts`` differ in length.
    """
    if len(bucket_limits) != len(bucket_counts):
        raise ValueError("len(bucket_limits) != len(bucket_counts)")

    proto = HistogramProto(
        min=float(min),
        max=float(max),
        num=int(num),
        sum=float(sum),
        sum_squares=float(sum_squares),
        bucket_limit=list(map(float, bucket_limits)),
        bucket=list(map(int, bucket_counts)),
    )
    summary = Summary(value=[Summary.Value(tag=tag, histo=proto)])
    self.pb_writer.add_summary(summary, global_step=global_step, walltime=walltime)
|
386
|
+
|
387
|
+
def add_gaussian_distribution(
|
388
|
+
self,
|
389
|
+
tag: str,
|
390
|
+
mean: float | np.ndarray,
|
391
|
+
std: float | np.ndarray,
|
392
|
+
bins: int = 16,
|
393
|
+
global_step: int | None = None,
|
394
|
+
walltime: float | None = None,
|
395
|
+
) -> None:
|
396
|
+
"""Add a Gaussian distribution to the summary.
|
397
|
+
|
398
|
+
Args:
|
399
|
+
tag: Data identifier
|
400
|
+
mean: Mean of the Gaussian distribution
|
401
|
+
std: Standard deviation of the Gaussian distribution
|
402
|
+
bins: Number of bins to use for the histogram
|
403
|
+
global_step: Global step value to record
|
404
|
+
walltime: Optional override default walltime
|
405
|
+
"""
|
406
|
+
# Convert numpy arrays to Python types
|
407
|
+
mean = float(mean)
|
408
|
+
std = float(std)
|
409
|
+
|
410
|
+
# Create bin edges spanning ±3 standard deviations
|
411
|
+
bin_edges = np.linspace(mean - 3 * std, mean + 3 * std, bins + 1)
|
412
|
+
|
413
|
+
# Calculate the probability density for each bin
|
414
|
+
bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
|
415
|
+
gaussian_pdf = np.exp(-0.5 * ((bin_centers - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))
|
416
|
+
|
417
|
+
# Scale the PDF to represent counts
|
418
|
+
num_samples = bins * 1000
|
419
|
+
bucket_counts = (gaussian_pdf * num_samples * (bin_edges[1] - bin_edges[0])).astype(int)
|
420
|
+
|
421
|
+
# Ensure we have at least one count per bin for visualization
|
422
|
+
bucket_counts = np.maximum(bucket_counts, 1)
|
423
|
+
|
424
|
+
# Calculate actual statistics based on the discretized distribution
|
425
|
+
total_counts = float(bucket_counts.sum())
|
426
|
+
weighted_sum = float((bin_centers * bucket_counts).sum())
|
427
|
+
weighted_sum_squares = float((bin_centers**2 * bucket_counts).sum())
|
428
|
+
|
429
|
+
self.add_histogram_raw(
|
430
|
+
tag=tag,
|
431
|
+
min=float(bin_edges[0]),
|
432
|
+
max=float(bin_edges[-1]),
|
433
|
+
num=int(total_counts),
|
434
|
+
sum=weighted_sum,
|
435
|
+
sum_squares=weighted_sum_squares,
|
436
|
+
bucket_limits=bin_edges[1:].tolist(), # TensorBoard expects right bin edges
|
437
|
+
bucket_counts=bucket_counts.tolist(),
|
438
|
+
global_step=global_step,
|
439
|
+
walltime=walltime,
|
440
|
+
)
|
441
|
+
|
266
442
|
|
267
443
|
class TensorboardWriterKwargs(TypedDict):
|
268
444
|
max_queue_size: int
|
@@ -1,12 +1,13 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: xax
|
3
|
-
Version: 0.0
|
4
|
-
Summary:
|
5
|
-
Home-page: https://github.com/
|
3
|
+
Version: 0.1.0
|
4
|
+
Summary: A library for fast Jax experimentation
|
5
|
+
Home-page: https://github.com/kscalelabs/xax
|
6
6
|
Author: Benjamin Bolte
|
7
7
|
Requires-Python: >=3.11
|
8
8
|
Description-Content-Type: text/markdown
|
9
9
|
License-File: LICENSE
|
10
|
+
Requires-Dist: attrs
|
10
11
|
Requires-Dist: jax
|
11
12
|
Requires-Dist: jaxtyping
|
12
13
|
Requires-Dist: equinox
|
@@ -30,10 +31,28 @@ Requires-Dist: pytest; extra == "dev"
|
|
30
31
|
Requires-Dist: types-pillow; extra == "dev"
|
31
32
|
Requires-Dist: types-psutil; extra == "dev"
|
32
33
|
Requires-Dist: types-requests; extra == "dev"
|
34
|
+
Provides-Extra: export
|
35
|
+
Requires-Dist: orbax-export; extra == "export"
|
36
|
+
Requires-Dist: tensorflow; extra == "export"
|
37
|
+
Provides-Extra: flax
|
38
|
+
Requires-Dist: flax; extra == "flax"
|
39
|
+
Provides-Extra: all
|
40
|
+
Requires-Dist: black; extra == "all"
|
41
|
+
Requires-Dist: darglint; extra == "all"
|
42
|
+
Requires-Dist: mypy; extra == "all"
|
43
|
+
Requires-Dist: ruff; extra == "all"
|
44
|
+
Requires-Dist: pytest; extra == "all"
|
45
|
+
Requires-Dist: types-pillow; extra == "all"
|
46
|
+
Requires-Dist: types-psutil; extra == "all"
|
47
|
+
Requires-Dist: types-requests; extra == "all"
|
48
|
+
Requires-Dist: orbax-export; extra == "all"
|
49
|
+
Requires-Dist: tensorflow; extra == "all"
|
50
|
+
Requires-Dist: flax; extra == "all"
|
33
51
|
Dynamic: author
|
34
52
|
Dynamic: description
|
35
53
|
Dynamic: description-content-type
|
36
54
|
Dynamic: home-page
|
55
|
+
Dynamic: license-file
|
37
56
|
Dynamic: provides-extra
|
38
57
|
Dynamic: requires-dist
|
39
58
|
Dynamic: requires-python
|