xax 0.0.7__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/utils/pytree.py CHANGED
@@ -4,7 +4,7 @@ import chex
4
4
  import jax
5
5
  import jax.numpy as jnp
6
6
  from jax import Array
7
- from jaxtyping import PyTree
7
+ from jaxtyping import PRNGKeyArray, PyTree
8
8
 
9
9
 
10
10
  def slice_array(x: Array, start: Array, slice_length: int) -> Array:
@@ -12,6 +12,14 @@ def slice_array(x: Array, start: Array, slice_length: int) -> Array:
12
12
 
13
13
  For multi-dimensional arrays, this slices only along the first dimension
14
14
  and keeps all other dimensions intact.
15
+
16
+ Args:
17
+ x: The array to slice.
18
+ start: The start index of the slice.
19
+ slice_length: The length of the slice.
20
+
21
+ Returns:
22
+ The sliced array.
15
23
  """
16
24
  chex.assert_shape(start, ())
17
25
  chex.assert_shape(slice_length, ())
@@ -38,6 +46,21 @@ def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
38
46
  return jax.tree_util.tree_map(lambda x: flatten_array(x, flatten_size), pytree)
39
47
 
40
48
 
49
+ def pytree_has_nans(pytree: PyTree) -> Array:
50
+ """Check if a pytree has any NaNs."""
51
+ has_nans = jax.tree_util.tree_reduce(
52
+ lambda a, b: jnp.logical_or(a, b),
53
+ jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree),
54
+ )
55
+ return has_nans
56
+
57
+
58
+ def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
59
+ """Update a pytree based on a condition."""
60
+ # Tricky, need use tree_map because where expects array leafs.
61
+ return jax.tree_util.tree_map(lambda x, y: jnp.where(cond, x, y), new, original)
62
+
63
+
41
64
  def compute_nan_ratio(pytree: PyTree) -> Array:
42
65
  """Computes the ratio of NaNs vs non-NaNs in a given PyTree."""
43
66
  nan_counts = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
@@ -48,3 +71,168 @@ def compute_nan_ratio(pytree: PyTree) -> Array:
48
71
  overall_nan_ratio = jnp.array(total_nans / total_elements)
49
72
 
50
73
  return overall_nan_ratio
74
+
75
+
76
+ def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
77
+ """Reshuffle a pytree along the leading dimensions.
78
+
79
+ This function reshuffles the data along the leading dimensions specified by batch_shape.
80
+ Assumes the dimensions to shuffle are the leading ones.
81
+
82
+ Args:
83
+ data: A pytree with arrays.
84
+ batch_shape: A tuple of integers specifying the size of each leading dimension to shuffle.
85
+ rng: A JAX PRNG key.
86
+
87
+ Returns:
88
+ A new pytree with the same structure but with the data reshuffled along
89
+ the leading dimensions.
90
+ """
91
+ # Create a permutation for the flattened batch dimensions
92
+ flat_size = 1
93
+ for dim in batch_shape:
94
+ flat_size *= dim
95
+
96
+ perm = jax.random.permutation(rng, flat_size)
97
+
98
+ def permute_array(x: Array) -> Array:
99
+ if not isinstance(x, jnp.ndarray):
100
+ return x
101
+
102
+ # Check if the array has enough dimensions
103
+ if len(x.shape) < len(batch_shape):
104
+ return x
105
+
106
+ # Check if the dimensions match the batch_shape
107
+ for i, dim in enumerate(batch_shape):
108
+ if x.shape[i] != dim:
109
+ return x
110
+
111
+ # Reshape to flatten the batch dimensions
112
+ orig_shape = x.shape
113
+ reshaped = x.reshape((flat_size,) + orig_shape[len(batch_shape) :])
114
+
115
+ # Apply the permutation
116
+ permuted = reshaped[perm]
117
+
118
+ # Reshape back to the original shape
119
+ return permuted.reshape(orig_shape)
120
+
121
+ return jax.tree_util.tree_map(permute_array, data)
122
+
123
+
124
+ def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
125
+ """Reshuffle a rollout array across arbitrary batch dimensions independently of each other."""
126
+ rngs = jax.random.split(rng, len(batch_shape))
127
+ perms = [jax.random.permutation(rng_i, dim) for rng_i, dim in zip(rngs, batch_shape)]
128
+ # n-dimensional index grid from permutations
129
+ idx_grids = jnp.meshgrid(*perms, indexing="ij")
130
+
131
+ def permute_array(x: Array) -> Array:
132
+ if isinstance(x, Array):
133
+ return x[tuple(idx_grids)]
134
+ return x
135
+
136
+ return jax.tree_util.tree_map(permute_array, data)
137
+
138
+
139
+ TransposeResult = tuple[PyTree, tuple[int, ...], tuple[int, ...]]
140
+ PathType = tuple[str | int, ...]
141
+
142
+
143
+ def reshuffle_pytree_along_dims(
144
+ data: PyTree,
145
+ dims: tuple[int, ...],
146
+ shape_dims: tuple[int, ...],
147
+ rng: PRNGKeyArray,
148
+ ) -> PyTree:
149
+ """Reshuffle a pytree along arbitrary dimensions.
150
+
151
+ Allows reshuffling along any dimensions, not just the leading ones.
152
+ It transposes the data to make the specified dimensions the leading ones,
153
+ then reshuffles, and finally transposes back.
154
+
155
+ Args:
156
+ data: A pytree with arrays.
157
+ dims: A tuple of integers specifying which dimensions to shuffle along.
158
+ For example, (1,) would shuffle along the second dimension.
159
+ shape_dims: A tuple of integers specifying the size of each dimension to shuffle.
160
+ Must have the same length as dims.
161
+ rng: A JAX PRNG key.
162
+
163
+ Returns:
164
+ A new pytree with the same structure but with the data reshuffled along
165
+ the specified dimensions.
166
+ """
167
+ if len(dims) != len(shape_dims):
168
+ raise ValueError(f"dims {dims} and shape_dims {shape_dims} must have the same length")
169
+
170
+ def transpose_for_shuffle(x: PyTree) -> TransposeResult:
171
+ if not isinstance(x, jnp.ndarray):
172
+ return x, (), ()
173
+
174
+ # Check if the array has enough dimensions
175
+ if len(x.shape) <= max(dims):
176
+ return x, (), ()
177
+
178
+ # Check if the dimensions match the shape_dims
179
+ for i, dim in enumerate(dims):
180
+ if x.shape[dim] != shape_dims[i]:
181
+ raise ValueError(f"Array shape {x.shape} doesn't match shape_dims {shape_dims} at dimension {dim}")
182
+
183
+ # Create the transpose order to move the specified dimensions to the front
184
+ # while preserving the relative order of the other dimensions
185
+ n_dims = len(x.shape)
186
+ other_dims = [i for i in range(n_dims) if i not in dims]
187
+ transpose_order = tuple(dims) + tuple(other_dims)
188
+
189
+ # Transpose the array
190
+ transposed = jnp.transpose(x, transpose_order)
191
+
192
+ return transposed, transpose_order, x.shape
193
+
194
+ def transpose_back(x: PyTree, transpose_order: tuple[int, ...], original_shape: tuple[int, ...]) -> PyTree:
195
+ if not isinstance(x, jnp.ndarray) or not transpose_order:
196
+ return x
197
+
198
+ # Create the inverse transpose order
199
+ inverse_order = [0] * len(transpose_order)
200
+ for i, j in enumerate(transpose_order):
201
+ inverse_order[j] = i
202
+
203
+ # Transpose back
204
+ return jnp.transpose(x, inverse_order)
205
+
206
+ # First, transpose all arrays to make the specified dimensions the leading ones
207
+ transposed_data: dict[PathType, Array] = {}
208
+ transpose_info: dict[PathType, tuple[tuple[int, ...], tuple[int, ...]]] = {}
209
+
210
+ def prepare_for_shuffle(path: PathType, x: PyTree) -> PyTree:
211
+ if isinstance(x, jnp.ndarray):
212
+ transposed, transpose_order, original_shape = transpose_for_shuffle(x)
213
+ if isinstance(transposed, jnp.ndarray): # Check if it's an array
214
+ transposed_data[path] = transposed
215
+ transpose_info[path] = (transpose_order, original_shape)
216
+ return x
217
+
218
+ jax.tree_util.tree_map_with_path(prepare_for_shuffle, data)
219
+
220
+ # Create a transposed pytree
221
+ def get_transposed(path: PathType, x: PyTree) -> PyTree:
222
+ if isinstance(x, jnp.ndarray) and path in transposed_data:
223
+ return transposed_data[path]
224
+ return x
225
+
226
+ transposed_pytree = jax.tree_util.tree_map_with_path(get_transposed, data)
227
+
228
+ # Reshuffle the transposed pytree along the leading dimensions
229
+ reshuffled_transposed = reshuffle_pytree(transposed_pytree, shape_dims, rng)
230
+
231
+ # Transpose back
232
+ def restore_transpose(path: PathType, x: PyTree) -> PyTree:
233
+ if isinstance(x, jnp.ndarray) and path in transpose_info:
234
+ transpose_order, original_shape = transpose_info[path]
235
+ return transpose_back(x, transpose_order, original_shape)
236
+ return x
237
+
238
+ return jax.tree_util.tree_map_with_path(restore_transpose, reshuffled_transposed)
xax/utils/tensorboard.py CHANGED
@@ -14,7 +14,7 @@ from PIL.Image import Image as PILImage
14
14
  from tensorboard.compat.proto.config_pb2 import RunMetadata
15
15
  from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
16
16
  from tensorboard.compat.proto.graph_pb2 import GraphDef
17
- from tensorboard.compat.proto.summary_pb2 import Summary, SummaryMetadata
17
+ from tensorboard.compat.proto.summary_pb2 import HistogramProto, Summary, SummaryMetadata
18
18
  from tensorboard.compat.proto.tensor_pb2 import TensorProto
19
19
  from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
20
20
  from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
@@ -25,6 +25,65 @@ from xax.core.state import Phase
25
25
  ImageShape = Literal["HWC", "CHW", "HW", "NHWC", "NCHW", "NHW"]
26
26
 
27
27
 
28
+ def make_histogram(values: np.ndarray, bins: str | np.ndarray, max_bins: int | None = None) -> HistogramProto:
29
+ """Convert values into a histogram proto using logic from histogram.cc.
30
+
31
+ Args:
32
+ values: Input values to create histogram from
33
+ bins: Bin specification (string like 'auto' or array of bin edges)
34
+ max_bins: Optional maximum number of bins
35
+
36
+ Returns:
37
+ HistogramProto containing the histogram data
38
+ """
39
+ if values.size == 0:
40
+ raise ValueError("The input has no element.")
41
+ values = values.reshape(-1)
42
+ counts, limits = np.histogram(values, bins=bins)
43
+ num_bins = len(counts)
44
+
45
+ if max_bins is not None and num_bins > max_bins:
46
+ subsampling = num_bins // max_bins
47
+ subsampling_remainder = num_bins % subsampling
48
+ if subsampling_remainder != 0:
49
+ counts = np.pad(
50
+ counts,
51
+ pad_width=[[0, subsampling - subsampling_remainder]],
52
+ mode="constant",
53
+ constant_values=0,
54
+ )
55
+ counts = counts.reshape(-1, subsampling).sum(axis=-1)
56
+ new_limits = np.empty((counts.size + 1,), limits.dtype)
57
+ new_limits[:-1] = limits[:-1:subsampling]
58
+ new_limits[-1] = limits[-1]
59
+ limits = new_limits
60
+
61
+ # Find the first and the last bin defining the support of the histogram:
62
+ cum_counts = np.cumsum(np.greater(counts, 0))
63
+ start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
64
+ start = int(start)
65
+ end = int(end) + 1
66
+
67
+ # TensorBoard only includes the right bin limits. To still have the leftmost limit
68
+ # included, we include an empty bin left.
69
+ counts = counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
70
+ limits = limits[start : end + 1]
71
+
72
+ if counts.size == 0 or limits.size == 0:
73
+ raise ValueError("The histogram is empty, please file a bug report.")
74
+
75
+ sum_sq = values.dot(values)
76
+ return HistogramProto(
77
+ min=values.min(),
78
+ max=values.max(),
79
+ num=len(values),
80
+ sum=values.sum(),
81
+ sum_squares=sum_sq,
82
+ bucket_limit=limits.tolist(),
83
+ bucket=counts.tolist(),
84
+ )
85
+
86
+
28
87
  class TensorboardProtobufWriter:
29
88
  def __init__(
30
89
  self,
@@ -263,6 +322,123 @@ class TensorboardWriter:
263
322
  walltime=walltime,
264
323
  )
265
324
 
325
+ def add_histogram(
326
+ self,
327
+ tag: str,
328
+ values: np.ndarray,
329
+ global_step: int | None = None,
330
+ bins: str | np.ndarray = "auto",
331
+ walltime: float | None = None,
332
+ max_bins: int | None = None,
333
+ ) -> None:
334
+ hist = make_histogram(values.astype(float), bins, max_bins)
335
+ self.pb_writer.add_summary(
336
+ Summary(value=[Summary.Value(tag=tag, histo=hist)]),
337
+ global_step=global_step,
338
+ walltime=walltime,
339
+ )
340
+
341
+ def add_histogram_raw(
342
+ self,
343
+ tag: str,
344
+ min: float | np.ndarray,
345
+ max: float | np.ndarray,
346
+ num: int | np.ndarray,
347
+ sum: float | np.ndarray,
348
+ sum_squares: float | np.ndarray,
349
+ bucket_limits: list[float | np.ndarray],
350
+ bucket_counts: list[int | np.ndarray],
351
+ global_step: int | None = None,
352
+ walltime: float | None = None,
353
+ ) -> None:
354
+ """Add histogram with raw data to summary.
355
+
356
+ Args:
357
+ tag: Data identifier
358
+ min: Min value
359
+ max: Max value
360
+ num: Number of values
361
+ sum: Sum of all values
362
+ sum_squares: Sum of squares for all values
363
+ bucket_limits: Upper value per bucket
364
+ bucket_counts: Number of values per bucket
365
+ global_step: Global step value to record
366
+ walltime: Optional override default walltime
367
+ """
368
+ if len(bucket_limits) != len(bucket_counts):
369
+ raise ValueError("len(bucket_limits) != len(bucket_counts)")
370
+
371
+ # Convert numpy arrays to Python types
372
+ hist = HistogramProto(
373
+ min=float(min),
374
+ max=float(max),
375
+ num=int(num),
376
+ sum=float(sum),
377
+ sum_squares=float(sum_squares),
378
+ bucket_limit=[float(x) for x in bucket_limits],
379
+ bucket=[int(x) for x in bucket_counts],
380
+ )
381
+ self.pb_writer.add_summary(
382
+ Summary(value=[Summary.Value(tag=tag, histo=hist)]),
383
+ global_step=global_step,
384
+ walltime=walltime,
385
+ )
386
+
387
+ def add_gaussian_distribution(
388
+ self,
389
+ tag: str,
390
+ mean: float | np.ndarray,
391
+ std: float | np.ndarray,
392
+ bins: int = 16,
393
+ global_step: int | None = None,
394
+ walltime: float | None = None,
395
+ ) -> None:
396
+ """Add a Gaussian distribution to the summary.
397
+
398
+ Args:
399
+ tag: Data identifier
400
+ mean: Mean of the Gaussian distribution
401
+ std: Standard deviation of the Gaussian distribution
402
+ bins: Number of bins to use for the histogram
403
+ global_step: Global step value to record
404
+ walltime: Optional override default walltime
405
+ """
406
+ # Convert numpy arrays to Python types
407
+ mean = float(mean)
408
+ std = float(std)
409
+
410
+ # Create bin edges spanning ±3 standard deviations
411
+ bin_edges = np.linspace(mean - 3 * std, mean + 3 * std, bins + 1)
412
+
413
+ # Calculate the probability density for each bin
414
+ bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
415
+ gaussian_pdf = np.exp(-0.5 * ((bin_centers - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))
416
+
417
+ # Scale the PDF to represent counts
418
+ num_samples = bins * 1000
419
+ bucket_counts = (gaussian_pdf * num_samples * (bin_edges[1] - bin_edges[0])).astype(int)
420
+
421
+ # Ensure we have at least one count per bin for visualization
422
+ bucket_counts = np.maximum(bucket_counts, 1)
423
+
424
+ # Calculate actual statistics based on the discretized distribution
425
+ total_counts = float(bucket_counts.sum())
426
+ weighted_sum = float((bin_centers * bucket_counts).sum())
427
+ weighted_sum_squares = float((bin_centers**2 * bucket_counts).sum())
428
+
429
+ self.add_histogram_raw(
430
+ tag=tag,
431
+ min=float(bin_edges[0]),
432
+ max=float(bin_edges[-1]),
433
+ num=int(total_counts),
434
+ sum=weighted_sum,
435
+ sum_squares=weighted_sum_squares,
436
+ bucket_limits=bin_edges[1:].tolist(), # TensorBoard expects right bin edges
437
+ bucket_counts=bucket_counts.tolist(),
438
+ global_step=global_step,
439
+ walltime=walltime,
440
+ )
441
+
266
442
 
267
443
  class TensorboardWriterKwargs(TypedDict):
268
444
  max_queue_size: int
@@ -1,12 +1,13 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.0.7
4
- Summary: The xax project
5
- Home-page: https://github.com/dpshai/xax
3
+ Version: 0.1.1
4
+ Summary: A library for fast Jax experimentation
5
+ Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
7
7
  Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
+ Requires-Dist: attrs
10
11
  Requires-Dist: jax
11
12
  Requires-Dist: jaxtyping
12
13
  Requires-Dist: equinox
@@ -30,10 +31,28 @@ Requires-Dist: pytest; extra == "dev"
30
31
  Requires-Dist: types-pillow; extra == "dev"
31
32
  Requires-Dist: types-psutil; extra == "dev"
32
33
  Requires-Dist: types-requests; extra == "dev"
34
+ Provides-Extra: export
35
+ Requires-Dist: orbax-export; extra == "export"
36
+ Requires-Dist: tensorflow; extra == "export"
37
+ Provides-Extra: flax
38
+ Requires-Dist: flax; extra == "flax"
39
+ Provides-Extra: all
40
+ Requires-Dist: black; extra == "all"
41
+ Requires-Dist: darglint; extra == "all"
42
+ Requires-Dist: mypy; extra == "all"
43
+ Requires-Dist: ruff; extra == "all"
44
+ Requires-Dist: pytest; extra == "all"
45
+ Requires-Dist: types-pillow; extra == "all"
46
+ Requires-Dist: types-psutil; extra == "all"
47
+ Requires-Dist: types-requests; extra == "all"
48
+ Requires-Dist: orbax-export; extra == "all"
49
+ Requires-Dist: tensorflow; extra == "all"
50
+ Requires-Dist: flax; extra == "all"
33
51
  Dynamic: author
34
52
  Dynamic: description
35
53
  Dynamic: description-content-type
36
54
  Dynamic: home-page
55
+ Dynamic: license-file
37
56
  Dynamic: provides-extra
38
57
  Dynamic: requires-dist
39
58
  Dynamic: requires-python
@@ -1,19 +1,22 @@
1
- xax/__init__.py,sha256=ScTkvKaxgpuKhhs9RINJa2XWCj899ndSYrB3FtScfxw,10509
1
+ xax/__init__.py,sha256=JyKRACir9b0bkuG93bwxADFrVr-Lo76kenDBJtvb_wQ,13280
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
- xax/requirements.txt,sha256=NmU9PNJhfLtNqqtWWf8WqMjgbBPCn_yt8oMGAgS7Fno,291
4
+ xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
5
5
  xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
7
7
  xax/core/state.py,sha256=y123fL7pMgk25TPG6KN0LRIF_eYnD9eP7OfqtoQJGNE,2178
8
8
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
+ xax/nn/equinox.py,sha256=1Ck6ycz76dhit2LHX4y2lp3WJSPsDuRt7TK7AxxQhww,4837
11
+ xax/nn/export.py,sha256=Do2bLjJTD744mxpQuPYpz8fZ3EIjBLaaZfhp8maNVrg,5303
10
12
  xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
11
- xax/nn/geom.py,sha256=MtVar9AdqrJQGIFxcIFHyFnV_fblf9Pc4kQT_gTQASI,2195
13
+ xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
14
+ xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
12
15
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
13
16
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- xax/task/base.py,sha256=LHDmM2c_Ps5cGEzn_QUpmyInD7zJJm3Yt9eSeij2Vus,7297
15
- xax/task/logger.py,sha256=orN1jmM4SIR2EiYk8bNoJZscmhX1FytADBU6p9qpows,29256
16
- xax/task/script.py,sha256=4LyXrpj0V36TjAZT4lvQeiOTqa5U2tommHKwgWDCE24,1025
17
+ xax/task/base.py,sha256=MlH5dTKAiMzFRI5fmXCvL1k8ELbalWMBICeVxmW6k2U,7479
18
+ xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
19
+ xax/task/script.py,sha256=zt36Sobdoer86gXHqc4sMAW7bqZRVl6IEExuQZH2USk,926
17
20
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
18
21
  xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
22
  xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
@@ -23,33 +26,35 @@ xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
23
26
  xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
24
27
  xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
25
28
  xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
26
- xax/task/loggers/stdout.py,sha256=nxQXkS9JUR38RKsU9qj0dgePKguK0BFa9nl_BdGO8cE,6758
27
- xax/task/loggers/tensorboard.py,sha256=FGW96z77oG0Kf3cO6Zznx5U3kJNzPWcuSkpY4RnbFCo,6909
29
+ xax/task/loggers/stdout.py,sha256=bR0k-PfmFgLfPxLPb4hZw_8G_msA32UeHfAAu11nEYs,6757
30
+ xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
28
31
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
29
- xax/task/mixins/artifacts.py,sha256=1H7ZbR-KSsXhVtqGVlqMi-TXfn1-dM7YnTCLVuw594s,3835
30
- xax/task/mixins/checkpointing.py,sha256=AMlobojybvJdDZcNCxm1DHSVC_2Qvnu_MbRcsc_8eoA,8508
32
+ xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
33
+ xax/task/mixins/checkpointing.py,sha256=sRkVxJbQfqDf1-lp1KFrAGYWHhTlV8_DORxGQ_69P1A,8954
31
34
  xax/task/mixins/compile.py,sha256=FRsxwLnZjjxpeWJ7Bx_d8XUY50oDoGidgpeRt4ejeQk,3377
32
35
  xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
33
36
  xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
34
37
  xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
35
- xax/task/mixins/logger.py,sha256=CIQ4w4K3FcxN6A9xUfITdVkulSxPa4iaTe6cbs9ruaM,1958
38
+ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
36
39
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
37
40
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
38
- xax/task/mixins/step_wrapper.py,sha256=DJw42mUGwgKx2tkeqatKR9_F4J8ug4wmxKMeJPmhcVQ,1560
39
- xax/task/mixins/train.py,sha256=dhGL_IuDaJy39BooYlO7JO-_EotKldtBhBplDGU_AnM,21745
41
+ xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
42
+ xax/task/mixins/train.py,sha256=BEC7HSwBlGZDe7jCsedqEA8-K1Zx52-bTjsBONYIE5g,22225
40
43
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
41
- xax/utils/experiments.py,sha256=qT3H0fyVH8DN417x7T0Xmz4SKoogW81-EHcZfyktFI8,28300
44
+ xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
45
+ xax/utils/experiments.py,sha256=_cwoBaiBxoQ_Tstm0rz7TEqfELqcktmPflb6AP1K0qA,28779
42
46
  xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
43
- xax/utils/logging.py,sha256=ST1hp2C2xntVVJBUHwo3YxPK19fBLNvHU2WGO1xqcXA,6418
47
+ xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
48
+ xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
44
49
  xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
45
50
  xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
46
- xax/utils/pytree.py,sha256=Jwx6ErJfv1r2D23D4eKz1Hoo3mAJ0SEqC3EagZarWkw,1858
47
- xax/utils/tensorboard.py,sha256=oGq2E3Yr0z2xaACv2UOVt_CHEVc8fBxI8V1M99Fd34E,9742
51
+ xax/utils/pytree.py,sha256=7GjQoPc_ZSZt3QS_9qXoBWl1jfMp1qZa7aViQoWJ0OQ,8864
52
+ xax/utils/tensorboard.py,sha256=_S70dS69pduiD05viHAGgYGsaBry1QL2ej6ZwUIXPOE,16170
48
53
  xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
49
54
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
55
  xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
51
- xax-0.0.7.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
52
- xax-0.0.7.dist-info/METADATA,sha256=hE0KO4kYcN6Ed8iZ4649R5ENOUaQysBMW9vTh-94d4I,1171
53
- xax-0.0.7.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
54
- xax-0.0.7.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
55
- xax-0.0.7.dist-info/RECORD,,
56
+ xax-0.1.1.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
57
+ xax-0.1.1.dist-info/METADATA,sha256=tJ4ilL3uBbykHBQTHbh-bN6m4hrHqivyyFeuI33ddX4,1877
58
+ xax-0.1.1.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
59
+ xax-0.1.1.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
60
+ xax-0.1.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (76.0.0)
2
+ Generator: setuptools (77.0.3)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5