xax 0.0.7__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/task/script.py CHANGED
@@ -17,8 +17,6 @@ from xax.task.mixins import (
17
17
  ProcessMixin,
18
18
  RunnableConfig,
19
19
  RunnableMixin,
20
- StepContextConfig,
21
- StepContextMixin,
22
20
  )
23
21
 
24
22
 
@@ -28,7 +26,6 @@ class ScriptConfig(
28
26
  GPUStatsConfig,
29
27
  ProcessConfig,
30
28
  LoggerConfig,
31
- StepContextConfig,
32
29
  ArtifactsConfig,
33
30
  RunnableConfig,
34
31
  BaseConfig,
@@ -44,7 +41,6 @@ class Script(
44
41
  GPUStatsMixin[ConfigT],
45
42
  ProcessMixin[ConfigT],
46
43
  LoggerMixin[ConfigT],
47
- StepContextMixin[ConfigT],
48
44
  ArtifactsMixin[ConfigT],
49
45
  RunnableMixin[ConfigT],
50
46
  BaseTask[ConfigT],
xax/utils/debugging.py ADDED
@@ -0,0 +1,49 @@
1
+ """Defines some useful Jax debugging utilities."""
2
+
3
+ from collections import deque
4
+ from collections.abc import Iterable, Mapping
5
+ from typing import Any, Callable, Deque
6
+
7
+ from jaxtyping import Array
8
+
9
+
10
def get_named_leaves(
    obj: Any,  # noqa: ANN401
    is_leaf: Callable[[Any], bool] = lambda x: isinstance(x, Array),  # noqa: ANN401
    max_depth: int = 100,
) -> list[tuple[str, Any]]:  # noqa: ANN401
    """Collects the leaves of a nested object together with dotted path names.

    The object is walked breadth-first. Objects exposing a mapping-like
    ``__dict__`` are expanded by attribute, mappings by key, and any other
    iterable by position. Strings are treated as atomic: iterating a string
    yields one-character strings which are themselves iterable, so expanding
    them would only burn ``max_depth`` levels of useless work per character.

    Args:
        obj: The root object to walk. The root itself is never tested with
            ``is_leaf``; only its descendants are.
        is_leaf: Predicate deciding whether a child is reported as a leaf.
        max_depth: Containers deeper than this are not expanded.

    Returns:
        A list of ``(name, leaf)`` pairs in breadth-first order, where ``name``
        is the dot-joined path of attribute names, keys and indices.
    """

    def iter_children(node: Any) -> Iterable[tuple[Any, Any]]:  # noqa: ANN401
        """Returns the (key, child) pairs of a container, or nothing for atoms."""
        if hasattr(node, "__dict__") and isinstance(node.__dict__, Mapping):
            return node.__dict__.items()
        if isinstance(node, Mapping):
            return node.items()
        if isinstance(node, str):
            return ()
        if isinstance(node, Iterable):
            return enumerate(node)
        return ()

    ret: list[tuple[str, Any]] = []
    q: Deque[tuple[int, str, Any]] = deque()  # noqa: ANN401
    q.append((0, "", obj))

    while q:
        depth, name, node = q.popleft()

        if depth > max_depth:
            continue

        for key, child in iter_children(node):
            # Always stringify the key so that names are real strings and a
            # falsy key (e.g. ``0``) does not erase the prefix of its children.
            gname = f"{name}.{key}" if name else str(key)
            if is_leaf(child):
                ret.append((gname, child))
            else:
                q.append((depth + 1, gname, child))

    return ret
xax/utils/experiments.py CHANGED
@@ -23,7 +23,8 @@ import urllib.request
23
23
  import warnings
24
24
  from abc import ABC, abstractmethod
25
25
  from pathlib import Path
26
- from typing import Any, Iterator, TypeVar, cast
26
+ from types import TracebackType
27
+ from typing import Any, Iterator, Self, TypeVar, cast
27
28
  from urllib.parse import urlparse
28
29
 
29
30
  import git
@@ -116,19 +117,19 @@ class StateTimer:
116
117
  logs: dict[str, dict[str, int | float]] = {}
117
118
 
118
119
  # Logs step statistics.
119
- logs[" steps"] = {
120
+ logs[" steps"] = {
120
121
  "total": self.step_timer.steps,
121
122
  "per-second": self.step_timer.steps_per_second,
122
123
  }
123
124
 
124
125
  # Logs sample statistics.
125
- logs[" samples"] = {
126
+ logs[" samples"] = {
126
127
  "total": self.sample_timer.steps,
127
128
  "per-second": self.sample_timer.steps_per_second,
128
129
  }
129
130
 
130
131
  # Logs full iteration statistics.
131
- logs["🔧 dt"] = {
132
+ logs[" dt"] = {
132
133
  "iter": self.iter_timer.iter_seconds,
133
134
  }
134
135
 
@@ -147,6 +148,24 @@ class IntervalTicker:
147
148
  return False
148
149
 
149
150
 
151
+ class ContextTimer:
152
+ def __init__(self) -> None:
153
+ self.start_time = 0.0
154
+ self.elapsed_time = 0.0
155
+
156
+ def __enter__(self) -> Self:
157
+ self.start_time = time.time()
158
+ return self
159
+
160
+ def __exit__(
161
+ self,
162
+ exc_type: type[BaseException] | None,
163
+ exc_value: BaseException | None,
164
+ traceback: TracebackType | None,
165
+ ) -> None:
166
+ self.elapsed_time = time.time() - self.start_time
167
+
168
+
150
169
  def abs_path(path: str) -> str:
151
170
  return str(Path(path).resolve())
152
171
 
xax/utils/jaxpr.py ADDED
@@ -0,0 +1,77 @@
1
+ """Visualize JAXPR."""
2
+
3
+ from pathlib import Path
4
+
5
+ import jax
6
+ import jax.core
7
+
8
+
9
def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) -> None:
    """Save the JAXPR to a DOT file.

    Example usage:

        grad_fn_jaxpr = jax.make_jaxpr(loss_fn)(variables)
        save_jaxpr_dot(grad_fn_jaxpr, "grad_fn_jaxpr.dot")

    Then, you can visualize the JAXPR using Graphviz:

        dot -Tpng grad_fn_jaxpr.dot > grad_fn_jaxpr.png

    Every node identifier is written as a quoted DOT string, because literal
    names such as ``lit_1.0`` or ``lit_-1.0`` are not valid bare DOT IDs.

    Args:
        closed_jaxpr: The closed JAXPR to save. An object without a ``jaxpr``
            attribute is used as the JAXPR directly.
        filename: The filename to save the JAXPR to.
    """
    jaxpr = closed_jaxpr.jaxpr if hasattr(closed_jaxpr, "jaxpr") else closed_jaxpr

    var_names: dict[jax.core.Var, str] = {}

    def get_var_name(var: jax.core.Var) -> str:
        """Get a unique name for a variable."""
        # Literals are not hashable, so they are named after their value
        # instead of being registered in ``var_names``.
        if isinstance(var, jax.core.Literal):
            return f"lit_{var.val}"

        if var not in var_names:
            var_names[var] = f"var_{len(var_names)}"
        return var_names[var]

    def escape(text: str) -> str:
        """Makes text safe to embed inside a double-quoted DOT string."""
        return text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", " ")

    def node_id(text: str) -> str:
        """Returns the quoted DOT identifier for a node name."""
        return f'"{escape(text)}"'

    # Build the whole document first so that a failure while walking the
    # JAXPR does not leave a truncated file behind.
    lines: list[str] = ["digraph Jaxpr {\n"]

    for var in jaxpr.invars:
        node_name = get_var_name(var)
        lines.append(f'  {node_id(node_name)} [label="{escape(node_name)}\\n(input)"];\n')

    for eq_count, eq in enumerate(jaxpr.eqns):
        eq_node = f"eq{eq_count}"
        lines.append(f'  {eq_node} [shape=box, label="{escape(str(eq.primitive.name))}"];\n')

        for invar in eq.invars:
            lines.append(f"  {node_id(get_var_name(invar))} -> {eq_node};\n")

        for outvar in eq.outvars:
            lines.append(f"  {eq_node} -> {node_id(get_var_name(outvar))};\n")

    for var in jaxpr.outvars:
        node_name = get_var_name(var)
        lines.append(f'  {node_id(node_name)} [peripheries=2, label="{escape(node_name)}\\n(output)"];\n')

    lines.append("}\n")

    with open(filename, "w", encoding="utf-8") as f:
        f.writelines(lines)
xax/utils/pytree.py CHANGED
@@ -4,7 +4,7 @@ import chex
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
  from jax import Array
7
- from jaxtyping import PyTree
7
+ from jaxtyping import PRNGKeyArray, PyTree
8
8
 
9
9
 
10
10
  def slice_array(x: Array, start: Array, slice_length: int) -> Array:
@@ -12,6 +12,14 @@ def slice_array(x: Array, start: Array, slice_length: int) -> Array:
12
12
 
13
13
  For multi-dimensional arrays, this slices only along the first dimension
14
14
  and keeps all other dimensions intact.
15
+
16
+ Args:
17
+ x: The array to slice.
18
+ start: The start index of the slice.
19
+ slice_length: The length of the slice.
20
+
21
+ Returns:
22
+ The sliced array.
15
23
  """
16
24
  chex.assert_shape(start, ())
17
25
  chex.assert_shape(slice_length, ())
@@ -38,6 +46,21 @@ def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
38
46
  return jax.tree_util.tree_map(lambda x: flatten_array(x, flatten_size), pytree)
39
47
 
40
48
 
49
def pytree_has_nans(pytree: PyTree) -> Array:
    """Check if a pytree has any NaNs.

    Args:
        pytree: The pytree whose leaves are checked.

    Returns:
        A scalar boolean array which is true if any leaf contains a NaN. A
        pytree without leaves yields false.
    """
    leaf_flags = jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree)
    # The explicit initializer keeps the reduction defined when there are no
    # leaves; without it, reducing an empty pytree raises.
    return jax.tree_util.tree_reduce(jnp.logical_or, leaf_flags, jnp.array(False))
56
+
57
+
58
def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
    """Choose between two pytrees leaf-by-leaf according to a condition.

    Args:
        cond: The condition passed to ``jnp.where`` for every pair of leaves.
        new: The pytree whose leaves are taken where ``cond`` is true.
        original: The pytree whose leaves are taken where ``cond`` is false.

    Returns:
        A pytree with the structure of ``new`` holding the selected leaves.
    """

    def select_leaf(new_leaf: Array, original_leaf: Array) -> Array:
        # ``jnp.where`` works on arrays, hence the per-leaf application.
        return jnp.where(cond, new_leaf, original_leaf)

    return jax.tree_util.tree_map(select_leaf, new, original)
62
+
63
+
41
64
  def compute_nan_ratio(pytree: PyTree) -> Array:
42
65
  """Computes the ratio of NaNs vs non-NaNs in a given PyTree."""
43
66
  nan_counts = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
@@ -48,3 +71,168 @@ def compute_nan_ratio(pytree: PyTree) -> Array:
48
71
  overall_nan_ratio = jnp.array(total_nans / total_elements)
49
72
 
50
73
  return overall_nan_ratio
74
+
75
+
76
def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
    """Shuffle every array leaf of a pytree over its leading batch dimensions.

    The leading dimensions given by ``batch_shape`` are flattened into one
    axis, permuted with a single permutation shared by all leaves, and then
    restored to their original shape.

    Args:
        data: A pytree with arrays.
        batch_shape: The sizes of the leading dimensions to shuffle over.
        rng: A JAX PRNG key.

    Returns:
        A pytree with the same structure as ``data``. Leaves that are not
        arrays, or whose leading dimensions do not equal ``batch_shape``, are
        returned unchanged.
    """
    num_batch_dims = len(batch_shape)

    total = 1
    for size in batch_shape:
        total = total * size

    order = jax.random.permutation(rng, total)

    def shuffle_leaf(leaf: Array) -> Array:
        if not isinstance(leaf, jnp.ndarray):
            return leaf

        # Skip leaves that are too short or whose leading sizes differ.
        leading = leaf.shape[:num_batch_dims]
        if len(leading) < num_batch_dims:
            return leaf
        if any(actual != expected for actual, expected in zip(leading, batch_shape)):
            return leaf

        trailing = leaf.shape[num_batch_dims:]
        flat = leaf.reshape((total,) + trailing)
        return flat[order].reshape(leaf.shape)

    return jax.tree_util.tree_map(shuffle_leaf, data)
122
+
123
+
124
def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
    """Shuffle array leaves along each leading batch dimension independently.

    One permutation is drawn per entry of ``batch_shape`` (each from its own
    split of ``rng``), and every array leaf is indexed with the grid formed by
    those permutations.

    Args:
        data: A pytree with arrays.
        batch_shape: The size of each leading dimension to shuffle.
        rng: A JAX PRNG key.

    Returns:
        A pytree with the same structure as ``data``; non-array leaves are
        returned unchanged.
    """
    keys = jax.random.split(rng, len(batch_shape))

    axis_perms = []
    for axis_key, axis_size in zip(keys, batch_shape):
        axis_perms.append(jax.random.permutation(axis_key, axis_size))

    # n-dimensional index grid built from the per-axis permutations.
    index = tuple(jnp.meshgrid(*axis_perms, indexing="ij"))

    def shuffle_leaf(leaf: Array) -> Array:
        return leaf[index] if isinstance(leaf, Array) else leaf

    return jax.tree_util.tree_map(shuffle_leaf, data)
137
+
138
+
139
+ TransposeResult = tuple[PyTree, tuple[int, ...], tuple[int, ...]]
140
+ PathType = tuple[str | int, ...]
141
+
142
+
143
def reshuffle_pytree_along_dims(
    data: PyTree,
    dims: tuple[int, ...],
    shape_dims: tuple[int, ...],
    rng: PRNGKeyArray,
) -> PyTree:
    """Reshuffle a pytree along arbitrary dimensions.

    Allows reshuffling along any dimensions, not just the leading ones.
    Each array leaf is transposed so that the specified dimensions lead, those
    dimensions are flattened and permuted with one permutation shared by all
    leaves, and the result is transposed back.

    Leaves that are not arrays, or that do not have enough dimensions to
    contain every entry of ``dims``, are returned untouched. In particular
    they are not shuffled along their leading axis, even when that axis
    happens to have a matching size.

    Args:
        data: A pytree with arrays.
        dims: A tuple of integers specifying which dimensions to shuffle along.
            For example, (1,) would shuffle along the second dimension.
        shape_dims: A tuple of integers specifying the size of each dimension to shuffle.
            Must have the same length as dims.
        rng: A JAX PRNG key.

    Returns:
        A new pytree with the same structure but with the data reshuffled along
        the specified dimensions.

    Raises:
        ValueError: If ``dims`` and ``shape_dims`` differ in length, or if an
            array leaf's size along one of ``dims`` differs from ``shape_dims``.
    """
    if len(dims) != len(shape_dims):
        raise ValueError(f"dims {dims} and shape_dims {shape_dims} must have the same length")

    # Nothing to shuffle along.
    if not dims:
        return data

    flat_size = 1
    for size in shape_dims:
        flat_size *= size

    # One permutation of the flattened shuffle dimensions, shared by all leaves.
    perm = jax.random.permutation(rng, flat_size)
    num_shuffled = len(dims)
    highest_dim = max(dims)

    def shuffle_leaf(x: Array) -> Array:
        if not isinstance(x, jnp.ndarray) or x.ndim <= highest_dim:
            return x

        for dim, expected in zip(dims, shape_dims):
            if x.shape[dim] != expected:
                raise ValueError(f"Array shape {x.shape} doesn't match shape_dims {shape_dims} at dimension {dim}")

        # Move the shuffled dimensions to the front, keeping the relative
        # order of the remaining dimensions.
        order = tuple(dims) + tuple(i for i in range(x.ndim) if i not in dims)
        moved = jnp.transpose(x, order)

        flat = moved.reshape((flat_size,) + moved.shape[num_shuffled:])
        shuffled = flat[perm].reshape(moved.shape)

        # The inverse of a permutation is its argsort.
        inverse = sorted(range(len(order)), key=order.__getitem__)
        return jnp.transpose(shuffled, inverse)

    return jax.tree_util.tree_map(shuffle_leaf, data)
xax/utils/tensorboard.py CHANGED
@@ -14,7 +14,7 @@ from PIL.Image import Image as PILImage
14
14
  from tensorboard.compat.proto.config_pb2 import RunMetadata
15
15
  from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
16
16
  from tensorboard.compat.proto.graph_pb2 import GraphDef
17
- from tensorboard.compat.proto.summary_pb2 import Summary, SummaryMetadata
17
+ from tensorboard.compat.proto.summary_pb2 import HistogramProto, Summary, SummaryMetadata
18
18
  from tensorboard.compat.proto.tensor_pb2 import TensorProto
19
19
  from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
20
20
  from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
@@ -25,6 +25,65 @@ from xax.core.state import Phase
25
25
  ImageShape = Literal["HWC", "CHW", "HW", "NHWC", "NCHW", "NHW"]
26
26
 
27
27
 
28
def make_histogram(values: np.ndarray, bins: str | np.ndarray, max_bins: int | None = None) -> HistogramProto:
    """Convert values into a histogram proto using logic from histogram.cc.

    The values are flattened and binned with ``np.histogram``. If ``max_bins``
    is given and exceeded, adjacent bins are merged in equally sized groups.
    Empty bins before the first and after the last non-empty bin are trimmed.

    Args:
        values: Input values to create histogram from
        bins: Bin specification (string like 'auto' or array of bin edges),
            forwarded to ``np.histogram``
        max_bins: Optional maximum number of bins

    Returns:
        HistogramProto containing the histogram data

    Raises:
        ValueError: If ``values`` has no elements, or if the trimmed histogram
            ends up with no counts or no limits.
    """
    if values.size == 0:
        raise ValueError("The input has no element.")
    # Flatten so that the statistics below are taken over every element.
    values = values.reshape(-1)
    counts, limits = np.histogram(values, bins=bins)
    num_bins = len(counts)

    if max_bins is not None and num_bins > max_bins:
        # Merge every `subsampling` adjacent bins into one.
        subsampling = num_bins // max_bins
        subsampling_remainder = num_bins % subsampling
        if subsampling_remainder != 0:
            # Right-pad with empty bins so the count length is a multiple of
            # `subsampling` and can be reshaped into groups.
            counts = np.pad(
                counts,
                pad_width=[[0, subsampling - subsampling_remainder]],
                mode="constant",
                constant_values=0,
            )
        counts = counts.reshape(-1, subsampling).sum(axis=-1)
        # Keep the left edge of every merged group plus the final right edge.
        new_limits = np.empty((counts.size + 1,), limits.dtype)
        new_limits[:-1] = limits[:-1:subsampling]
        new_limits[-1] = limits[-1]
        limits = new_limits

    # Find the first and the last bin defining the support of the histogram:
    # `cum_counts` is the running number of non-empty bins, so searching for 0
    # gives the index of the first non-empty bin and searching for
    # `cum_counts[-1] - 1` gives the index of the last non-empty bin.
    cum_counts = np.cumsum(np.greater(counts, 0))
    start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
    start = int(start)
    # Make `end` an exclusive bound.
    end = int(end) + 1

    # TensorBoard only includes the right bin limits. To still have the leftmost limit
    # included, we include an empty bin left.
    counts = counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
    limits = limits[start : end + 1]

    if counts.size == 0 or limits.size == 0:
        raise ValueError("The histogram is empty, please file a bug report.")

    # Summary statistics are computed from the raw values, not from the bins.
    sum_sq = values.dot(values)
    return HistogramProto(
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limit=limits.tolist(),
        bucket=counts.tolist(),
    )
85
+
86
+
28
87
  class TensorboardProtobufWriter:
29
88
  def __init__(
30
89
  self,
@@ -263,6 +322,123 @@ class TensorboardWriter:
263
322
  walltime=walltime,
264
323
  )
265
324
 
325
def add_histogram(
    self,
    tag: str,
    values: np.ndarray,
    global_step: int | None = None,
    bins: str | np.ndarray = "auto",
    walltime: float | None = None,
    max_bins: int | None = None,
) -> None:
    """Add a histogram of the given values to the summary.

    Args:
        tag: Data identifier
        values: Values to build the histogram from; cast to float first
        global_step: Global step value to record
        bins: Bin specification forwarded to ``make_histogram``
        walltime: Optional override default walltime
        max_bins: Optional maximum number of bins
    """
    histogram_proto = make_histogram(values.astype(float), bins, max_bins)
    summary = Summary(value=[Summary.Value(tag=tag, histo=histogram_proto)])
    self.pb_writer.add_summary(summary, global_step=global_step, walltime=walltime)
340
+
341
+ def add_histogram_raw(
342
+ self,
343
+ tag: str,
344
+ min: float | np.ndarray,
345
+ max: float | np.ndarray,
346
+ num: int | np.ndarray,
347
+ sum: float | np.ndarray,
348
+ sum_squares: float | np.ndarray,
349
+ bucket_limits: list[float | np.ndarray],
350
+ bucket_counts: list[int | np.ndarray],
351
+ global_step: int | None = None,
352
+ walltime: float | None = None,
353
+ ) -> None:
354
+ """Add histogram with raw data to summary.
355
+
356
+ Args:
357
+ tag: Data identifier
358
+ min: Min value
359
+ max: Max value
360
+ num: Number of values
361
+ sum: Sum of all values
362
+ sum_squares: Sum of squares for all values
363
+ bucket_limits: Upper value per bucket
364
+ bucket_counts: Number of values per bucket
365
+ global_step: Global step value to record
366
+ walltime: Optional override default walltime
367
+ """
368
+ if len(bucket_limits) != len(bucket_counts):
369
+ raise ValueError("len(bucket_limits) != len(bucket_counts)")
370
+
371
+ # Convert numpy arrays to Python types
372
+ hist = HistogramProto(
373
+ min=float(min),
374
+ max=float(max),
375
+ num=int(num),
376
+ sum=float(sum),
377
+ sum_squares=float(sum_squares),
378
+ bucket_limit=[float(x) for x in bucket_limits],
379
+ bucket=[int(x) for x in bucket_counts],
380
+ )
381
+ self.pb_writer.add_summary(
382
+ Summary(value=[Summary.Value(tag=tag, histo=hist)]),
383
+ global_step=global_step,
384
+ walltime=walltime,
385
+ )
386
+
387
+ def add_gaussian_distribution(
388
+ self,
389
+ tag: str,
390
+ mean: float | np.ndarray,
391
+ std: float | np.ndarray,
392
+ bins: int = 16,
393
+ global_step: int | None = None,
394
+ walltime: float | None = None,
395
+ ) -> None:
396
+ """Add a Gaussian distribution to the summary.
397
+
398
+ Args:
399
+ tag: Data identifier
400
+ mean: Mean of the Gaussian distribution
401
+ std: Standard deviation of the Gaussian distribution
402
+ bins: Number of bins to use for the histogram
403
+ global_step: Global step value to record
404
+ walltime: Optional override default walltime
405
+ """
406
+ # Convert numpy arrays to Python types
407
+ mean = float(mean)
408
+ std = float(std)
409
+
410
+ # Create bin edges spanning ±3 standard deviations
411
+ bin_edges = np.linspace(mean - 3 * std, mean + 3 * std, bins + 1)
412
+
413
+ # Calculate the probability density for each bin
414
+ bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
415
+ gaussian_pdf = np.exp(-0.5 * ((bin_centers - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))
416
+
417
+ # Scale the PDF to represent counts
418
+ num_samples = bins * 1000
419
+ bucket_counts = (gaussian_pdf * num_samples * (bin_edges[1] - bin_edges[0])).astype(int)
420
+
421
+ # Ensure we have at least one count per bin for visualization
422
+ bucket_counts = np.maximum(bucket_counts, 1)
423
+
424
+ # Calculate actual statistics based on the discretized distribution
425
+ total_counts = float(bucket_counts.sum())
426
+ weighted_sum = float((bin_centers * bucket_counts).sum())
427
+ weighted_sum_squares = float((bin_centers**2 * bucket_counts).sum())
428
+
429
+ self.add_histogram_raw(
430
+ tag=tag,
431
+ min=float(bin_edges[0]),
432
+ max=float(bin_edges[-1]),
433
+ num=int(total_counts),
434
+ sum=weighted_sum,
435
+ sum_squares=weighted_sum_squares,
436
+ bucket_limits=bin_edges[1:].tolist(), # TensorBoard expects right bin edges
437
+ bucket_counts=bucket_counts.tolist(),
438
+ global_step=global_step,
439
+ walltime=walltime,
440
+ )
441
+
266
442
 
267
443
  class TensorboardWriterKwargs(TypedDict):
268
444
  max_queue_size: int
@@ -1,12 +1,13 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.0.7
4
- Summary: The xax project
5
- Home-page: https://github.com/dpshai/xax
3
+ Version: 0.1.0
4
+ Summary: A library for fast Jax experimentation
5
+ Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
7
7
  Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
+ Requires-Dist: attrs
10
11
  Requires-Dist: jax
11
12
  Requires-Dist: jaxtyping
12
13
  Requires-Dist: equinox
@@ -30,10 +31,28 @@ Requires-Dist: pytest; extra == "dev"
30
31
  Requires-Dist: types-pillow; extra == "dev"
31
32
  Requires-Dist: types-psutil; extra == "dev"
32
33
  Requires-Dist: types-requests; extra == "dev"
34
+ Provides-Extra: export
35
+ Requires-Dist: orbax-export; extra == "export"
36
+ Requires-Dist: tensorflow; extra == "export"
37
+ Provides-Extra: flax
38
+ Requires-Dist: flax; extra == "flax"
39
+ Provides-Extra: all
40
+ Requires-Dist: black; extra == "all"
41
+ Requires-Dist: darglint; extra == "all"
42
+ Requires-Dist: mypy; extra == "all"
43
+ Requires-Dist: ruff; extra == "all"
44
+ Requires-Dist: pytest; extra == "all"
45
+ Requires-Dist: types-pillow; extra == "all"
46
+ Requires-Dist: types-psutil; extra == "all"
47
+ Requires-Dist: types-requests; extra == "all"
48
+ Requires-Dist: orbax-export; extra == "all"
49
+ Requires-Dist: tensorflow; extra == "all"
50
+ Requires-Dist: flax; extra == "all"
33
51
  Dynamic: author
34
52
  Dynamic: description
35
53
  Dynamic: description-content-type
36
54
  Dynamic: home-page
55
+ Dynamic: license-file
37
56
  Dynamic: provides-extra
38
57
  Dynamic: requires-dist
39
58
  Dynamic: requires-python