xax 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/utils/pytree.py ADDED
@@ -0,0 +1,238 @@
1
+ """Utils for accessing, modifying, and otherwise manipulating pytrees."""
2
+
3
+ import chex
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from jax import Array
7
+ from jaxtyping import PRNGKeyArray, PyTree
8
+
9
+
10
def slice_array(x: Array, start: Array, slice_length: int) -> Array:
    """Take ``slice_length`` consecutive entries of ``x`` along its first axis.

    Only the leading axis is sliced; every trailing axis is kept whole. The
    slice is produced with ``jax.lax.dynamic_slice``, so ``start`` may be a
    traced value while ``slice_length`` must be a static size.

    Args:
        x: The array to slice.
        start: Scalar start index along the first axis.
        slice_length: Scalar number of entries to take along the first axis.

    Returns:
        An array of shape ``(slice_length, *x.shape[1:])``.
    """
    chex.assert_shape(start, ())
    chex.assert_shape(slice_length, ())
    trailing_dims = x.shape[1:]
    # Offset only the first axis; every other axis starts at index 0.
    offsets = (start, *([0] * len(trailing_dims)))
    sizes = (slice_length, *trailing_dims)
    return jax.lax.dynamic_slice(x, offsets, sizes)
30
+
31
+
32
def slice_pytree(pytree: PyTree, start: Array, slice_length: int) -> PyTree:
    """Slice every leaf of a pytree along its first axis.

    Each leaf goes through :func:`slice_array` with the same ``start`` and
    ``slice_length``; the tree structure is unchanged.
    """

    def _slice_leaf(leaf: Array) -> Array:
        return slice_array(leaf, start, slice_length)

    return jax.tree_util.tree_map(_slice_leaf, pytree)
35
+
36
+
37
def flatten_array(x: Array, flatten_size: int) -> Array:
    """Reshape ``x`` to ``(flatten_size, *x.shape[2:])``.

    For an array with at least two axes this merges the first two axes into a
    single axis, so ``flatten_size`` should equal ``x.shape[0] * x.shape[1]``;
    otherwise the reshape fails.
    """
    target_shape = (flatten_size,) + tuple(x.shape[2:])
    flattened = jnp.reshape(x, target_shape)
    assert flattened.shape[0] == flatten_size
    return flattened
42
+
43
+
44
def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
    """Apply :func:`flatten_array` to every leaf of a pytree.

    Each leaf is reshaped to ``(flatten_size, *leaf.shape[2:])``, which merges
    its first two axes when it has at least two.
    """

    def _flatten_leaf(leaf: Array) -> Array:
        return flatten_array(leaf, flatten_size)

    return jax.tree_util.tree_map(_flatten_leaf, pytree)
47
+
48
+
49
def pytree_has_nans(pytree: PyTree) -> Array:
    """Check whether any leaf of a pytree contains a NaN.

    Args:
        pytree: A pytree whose leaves are arrays accepted by ``jnp.isnan``.

    Returns:
        A scalar boolean array that is True if at least one element of any leaf
        is NaN. A pytree with no leaves yields False.
    """
    per_leaf = jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree)
    # The explicit initializer makes the reduction well-defined for a pytree with
    # no leaves; without it ``tree_reduce`` raises a TypeError on empty input.
    return jax.tree_util.tree_reduce(lambda a, b: jnp.logical_or(a, b), per_leaf, jnp.array(False))
56
+
57
+
58
def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
    """Select leaf-wise between two pytrees with the same structure.

    For each pair of corresponding leaves this returns
    ``jnp.where(cond, new_leaf, original_leaf)``, so ``cond`` is broadcast
    against every leaf. The mapping is done with ``tree_map`` because
    ``jnp.where`` only accepts array leaves, not whole pytrees.

    Args:
        cond: Boolean condition (scalar or broadcastable to the leaves).
        new: Pytree whose leaves are taken where ``cond`` is True.
        original: Pytree whose leaves are taken where ``cond`` is False.

    Returns:
        A pytree with the structure of ``new``.
    """

    def _select(new_leaf: Array, original_leaf: Array) -> Array:
        return jnp.where(cond, new_leaf, original_leaf)

    return jax.tree_util.tree_map(_select, new, original)
62
+
63
+
64
def compute_nan_ratio(pytree: PyTree) -> Array:
    """Compute the fraction of NaN elements across all leaves of a pytree.

    The result is ``num_nan_elements / num_total_elements``; it is the fraction
    of all elements, not a NaN-to-non-NaN ratio.

    Args:
        pytree: A pytree whose leaves are arrays (each must expose ``.size``).

    Returns:
        A scalar array holding the NaN fraction. If the pytree contains no
        elements at all (no leaves, or only zero-size leaves), 0.0 is returned
        instead of dividing by zero.
    """
    nan_counts = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
    total_counts = jax.tree_util.tree_map(lambda x: x.size, pytree)

    total_nans = jax.tree_util.tree_reduce(lambda a, b: a + b, nan_counts, 0)
    # ``x.size`` is a static Python int, so this total is static even under jit.
    total_elements = jax.tree_util.tree_reduce(lambda a, b: a + b, total_counts, 0)
    if total_elements == 0:
        return jnp.array(0.0)

    return jnp.array(total_nans / total_elements)
74
+
75
+
76
def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
    """Shuffle matching leaves of a pytree jointly along their leading axes.

    The leading ``len(batch_shape)`` axes of each matching leaf are flattened
    into one axis, permuted, and reshaped back. Every matching leaf uses the same
    permutation, so corresponding entries stay aligned across leaves.

    Leaves that are not JAX arrays, or whose leading axes do not exactly equal
    ``batch_shape``, are returned untouched.

    Args:
        data: A pytree with arrays.
        batch_shape: Size of each leading axis to shuffle.
        rng: A JAX PRNG key.

    Returns:
        A pytree with the same structure as ``data`` in which matching leaves
        have been shuffled along their leading axes.
    """
    num_items = 1
    for size in batch_shape:
        num_items *= size

    # A single permutation shared by every leaf keeps rows aligned across leaves.
    order = jax.random.permutation(rng, num_items)
    num_batch_axes = len(batch_shape)
    expected_leading = tuple(batch_shape)

    def _shuffle_leaf(leaf: Array) -> Array:
        if not isinstance(leaf, jnp.ndarray):
            return leaf
        # A leaf with too few axes yields a shorter prefix, which never matches.
        if tuple(leaf.shape[:num_batch_axes]) != expected_leading:
            return leaf
        leaf_shape = leaf.shape
        merged = leaf.reshape((num_items,) + leaf_shape[num_batch_axes:])
        return merged[order].reshape(leaf_shape)

    return jax.tree_util.tree_map(_shuffle_leaf, data)
122
+
123
+
124
def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
    """Shuffle each leading axis of the array leaves with its own permutation.

    One permutation is drawn per entry of ``batch_shape``, each from its own
    sub-key of ``rng``. The permutations are expanded with an ``"ij"`` meshgrid
    into an advanced index, so output element ``[i, j, ...]`` is input element
    ``[perm_0[i], perm_1[j], ...]``. All ``jax.Array`` leaves are indexed with
    the same permutations; any other leaf is returned unchanged.

    NOTE(review): unlike ``reshuffle_pytree``, leaves are not checked against
    ``batch_shape``; confirm every array leaf is expected to carry these
    leading axes.

    Args:
        data: A pytree with arrays.
        batch_shape: Size of each leading axis to shuffle.
        rng: A JAX PRNG key.

    Returns:
        A pytree with the same structure as ``data`` with array leaves shuffled.
    """
    axis_keys = jax.random.split(rng, len(batch_shape))
    axis_perms = [jax.random.permutation(axis_key, size) for axis_key, size in zip(axis_keys, batch_shape)]
    # Broadcast the 1-D per-axis permutations into one n-D advanced index.
    gather_index = tuple(jnp.meshgrid(*axis_perms, indexing="ij"))

    def _gather(leaf: Array) -> Array:
        if not isinstance(leaf, Array):
            return leaf
        return leaf[gather_index]

    return jax.tree_util.tree_map(_gather, data)
137
+
138
+
139
+ TransposeResult = tuple[PyTree, tuple[int, ...], tuple[int, ...]]
140
+ PathType = tuple[str | int, ...]
141
+
142
+
143
def reshuffle_pytree_along_dims(
    data: PyTree,
    dims: tuple[int, ...],
    shape_dims: tuple[int, ...],
    rng: PRNGKeyArray,
) -> PyTree:
    """Reshuffle a pytree along arbitrary dimensions.

    For every eligible leaf, the axes listed in ``dims`` are moved to the front
    (other axes keep their relative order), flattened into a single axis,
    permuted, and then restored to the original layout. All eligible leaves
    share one permutation, drawn exactly as ``reshuffle_pytree`` draws it for
    ``batch_shape=shape_dims``.

    A leaf is eligible when it is a JAX array with more than ``max(dims)`` axes.
    Every other leaf (non-arrays and arrays with too few axes) is returned
    unchanged. Low-rank arrays are never shuffled along some other axis just
    because their leading sizes happen to equal ``shape_dims``.

    Args:
        data: A pytree with arrays.
        dims: Which dimensions to shuffle along. For example, (1,) shuffles
            along the second dimension. An empty tuple leaves ``data`` unchanged.
        shape_dims: Size of each dimension in ``dims``. Must have the same
            length as ``dims``.
        rng: A JAX PRNG key.

    Returns:
        A new pytree with the same structure, with eligible leaves reshuffled
        along the specified dimensions.

    Raises:
        ValueError: If ``dims`` and ``shape_dims`` differ in length, or if an
            eligible leaf's size along one of ``dims`` differs from the
            corresponding entry of ``shape_dims``.
    """
    if len(dims) != len(shape_dims):
        raise ValueError(f"dims {dims} and shape_dims {shape_dims} must have the same length")
    if not dims:
        # Shuffling over zero axes is the identity (and ``max(())`` would fail).
        return data

    num_items = 1
    for size in shape_dims:
        num_items *= size
    # Same draw as ``reshuffle_pytree`` makes, shared by every eligible leaf.
    perm = jax.random.permutation(rng, num_items)
    lead_dims = tuple(dims)
    min_ndim = max(lead_dims) + 1

    def shuffle_leaf(x: PyTree) -> PyTree:
        if not isinstance(x, jnp.ndarray) or x.ndim < min_ndim:
            return x

        for dim, size in zip(lead_dims, shape_dims):
            if x.shape[dim] != size:
                raise ValueError(f"Array shape {x.shape} doesn't match shape_dims {shape_dims} at dimension {dim}")

        # Move the shuffled axes to the front, preserving the order of the others.
        other_dims = tuple(i for i in range(x.ndim) if i not in lead_dims)
        transpose_order = lead_dims + other_dims
        transposed = jnp.transpose(x, transpose_order)

        # Flatten the leading axes, permute them, and restore the transposed shape.
        transposed_shape = transposed.shape
        flat = transposed.reshape((num_items,) + transposed_shape[len(lead_dims) :])
        shuffled = flat[perm].reshape(transposed_shape)

        # Invert the transpose order so every axis returns to its original position.
        inverse_order = [0] * len(transpose_order)
        for new_pos, old_axis in enumerate(transpose_order):
            inverse_order[old_axis] = new_pos
        return jnp.transpose(shuffled, inverse_order)

    return jax.tree_util.tree_map(shuffle_leaf, data)
xax/utils/tensorboard.py CHANGED
@@ -14,7 +14,7 @@ from PIL.Image import Image as PILImage
14
14
  from tensorboard.compat.proto.config_pb2 import RunMetadata
15
15
  from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
16
16
  from tensorboard.compat.proto.graph_pb2 import GraphDef
17
- from tensorboard.compat.proto.summary_pb2 import Summary, SummaryMetadata
17
+ from tensorboard.compat.proto.summary_pb2 import HistogramProto, Summary, SummaryMetadata
18
18
  from tensorboard.compat.proto.tensor_pb2 import TensorProto
19
19
  from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
20
20
  from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
@@ -25,6 +25,65 @@ from xax.core.state import Phase
25
25
# Literal type enumerating the accepted image layout strings.
# NOTE(review): the letters presumably follow the usual N/C/H/W convention
# (batch / channel / height / width); confirm against the image helpers.
ImageShape = Literal["HWC", "CHW", "HW", "NHWC", "NCHW", "NHW"]
26
26
 
27
27
 
28
+ def make_histogram(values: np.ndarray, bins: str | np.ndarray, max_bins: int | None = None) -> HistogramProto:
29
+ """Convert values into a histogram proto using logic from histogram.cc.
30
+
31
+ Args:
32
+ values: Input values to create histogram from
33
+ bins: Bin specification (string like 'auto' or array of bin edges)
34
+ max_bins: Optional maximum number of bins
35
+
36
+ Returns:
37
+ HistogramProto containing the histogram data
38
+ """
39
+ if values.size == 0:
40
+ raise ValueError("The input has no element.")
41
+ values = values.reshape(-1)
42
+ counts, limits = np.histogram(values, bins=bins)
43
+ num_bins = len(counts)
44
+
45
+ if max_bins is not None and num_bins > max_bins:
46
+ subsampling = num_bins // max_bins
47
+ subsampling_remainder = num_bins % subsampling
48
+ if subsampling_remainder != 0:
49
+ counts = np.pad(
50
+ counts,
51
+ pad_width=[[0, subsampling - subsampling_remainder]],
52
+ mode="constant",
53
+ constant_values=0,
54
+ )
55
+ counts = counts.reshape(-1, subsampling).sum(axis=-1)
56
+ new_limits = np.empty((counts.size + 1,), limits.dtype)
57
+ new_limits[:-1] = limits[:-1:subsampling]
58
+ new_limits[-1] = limits[-1]
59
+ limits = new_limits
60
+
61
+ # Find the first and the last bin defining the support of the histogram:
62
+ cum_counts = np.cumsum(np.greater(counts, 0))
63
+ start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
64
+ start = int(start)
65
+ end = int(end) + 1
66
+
67
+ # TensorBoard only includes the right bin limits. To still have the leftmost limit
68
+ # included, we include an empty bin left.
69
+ counts = counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
70
+ limits = limits[start : end + 1]
71
+
72
+ if counts.size == 0 or limits.size == 0:
73
+ raise ValueError("The histogram is empty, please file a bug report.")
74
+
75
+ sum_sq = values.dot(values)
76
+ return HistogramProto(
77
+ min=values.min(),
78
+ max=values.max(),
79
+ num=len(values),
80
+ sum=values.sum(),
81
+ sum_squares=sum_sq,
82
+ bucket_limit=limits.tolist(),
83
+ bucket=counts.tolist(),
84
+ )
85
+
86
+
28
87
  class TensorboardProtobufWriter:
29
88
  def __init__(
30
89
  self,
@@ -263,6 +322,123 @@ class TensorboardWriter:
263
322
  walltime=walltime,
264
323
  )
265
324
 
325
+ def add_histogram(
326
+ self,
327
+ tag: str,
328
+ values: np.ndarray,
329
+ global_step: int | None = None,
330
+ bins: str | np.ndarray = "auto",
331
+ walltime: float | None = None,
332
+ max_bins: int | None = None,
333
+ ) -> None:
334
+ hist = make_histogram(values.astype(float), bins, max_bins)
335
+ self.pb_writer.add_summary(
336
+ Summary(value=[Summary.Value(tag=tag, histo=hist)]),
337
+ global_step=global_step,
338
+ walltime=walltime,
339
+ )
340
+
341
+ def add_histogram_raw(
342
+ self,
343
+ tag: str,
344
+ min: float | np.ndarray,
345
+ max: float | np.ndarray,
346
+ num: int | np.ndarray,
347
+ sum: float | np.ndarray,
348
+ sum_squares: float | np.ndarray,
349
+ bucket_limits: list[float | np.ndarray],
350
+ bucket_counts: list[int | np.ndarray],
351
+ global_step: int | None = None,
352
+ walltime: float | None = None,
353
+ ) -> None:
354
+ """Add histogram with raw data to summary.
355
+
356
+ Args:
357
+ tag: Data identifier
358
+ min: Min value
359
+ max: Max value
360
+ num: Number of values
361
+ sum: Sum of all values
362
+ sum_squares: Sum of squares for all values
363
+ bucket_limits: Upper value per bucket
364
+ bucket_counts: Number of values per bucket
365
+ global_step: Global step value to record
366
+ walltime: Optional override default walltime
367
+ """
368
+ if len(bucket_limits) != len(bucket_counts):
369
+ raise ValueError("len(bucket_limits) != len(bucket_counts)")
370
+
371
+ # Convert numpy arrays to Python types
372
+ hist = HistogramProto(
373
+ min=float(min),
374
+ max=float(max),
375
+ num=int(num),
376
+ sum=float(sum),
377
+ sum_squares=float(sum_squares),
378
+ bucket_limit=[float(x) for x in bucket_limits],
379
+ bucket=[int(x) for x in bucket_counts],
380
+ )
381
+ self.pb_writer.add_summary(
382
+ Summary(value=[Summary.Value(tag=tag, histo=hist)]),
383
+ global_step=global_step,
384
+ walltime=walltime,
385
+ )
386
+
387
+ def add_gaussian_distribution(
388
+ self,
389
+ tag: str,
390
+ mean: float | np.ndarray,
391
+ std: float | np.ndarray,
392
+ bins: int = 16,
393
+ global_step: int | None = None,
394
+ walltime: float | None = None,
395
+ ) -> None:
396
+ """Add a Gaussian distribution to the summary.
397
+
398
+ Args:
399
+ tag: Data identifier
400
+ mean: Mean of the Gaussian distribution
401
+ std: Standard deviation of the Gaussian distribution
402
+ bins: Number of bins to use for the histogram
403
+ global_step: Global step value to record
404
+ walltime: Optional override default walltime
405
+ """
406
+ # Convert numpy arrays to Python types
407
+ mean = float(mean)
408
+ std = float(std)
409
+
410
+ # Create bin edges spanning ±3 standard deviations
411
+ bin_edges = np.linspace(mean - 3 * std, mean + 3 * std, bins + 1)
412
+
413
+ # Calculate the probability density for each bin
414
+ bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
415
+ gaussian_pdf = np.exp(-0.5 * ((bin_centers - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))
416
+
417
+ # Scale the PDF to represent counts
418
+ num_samples = bins * 1000
419
+ bucket_counts = (gaussian_pdf * num_samples * (bin_edges[1] - bin_edges[0])).astype(int)
420
+
421
+ # Ensure we have at least one count per bin for visualization
422
+ bucket_counts = np.maximum(bucket_counts, 1)
423
+
424
+ # Calculate actual statistics based on the discretized distribution
425
+ total_counts = float(bucket_counts.sum())
426
+ weighted_sum = float((bin_centers * bucket_counts).sum())
427
+ weighted_sum_squares = float((bin_centers**2 * bucket_counts).sum())
428
+
429
+ self.add_histogram_raw(
430
+ tag=tag,
431
+ min=float(bin_edges[0]),
432
+ max=float(bin_edges[-1]),
433
+ num=int(total_counts),
434
+ sum=weighted_sum,
435
+ sum_squares=weighted_sum_squares,
436
+ bucket_limits=bin_edges[1:].tolist(), # TensorBoard expects right bin edges
437
+ bucket_counts=bucket_counts.tolist(),
438
+ global_step=global_step,
439
+ walltime=walltime,
440
+ )
441
+
266
442
 
267
443
  class TensorboardWriterKwargs(TypedDict):
268
444
  max_queue_size: int
@@ -1,12 +1,13 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.0.6
4
- Summary: The xax project
5
- Home-page: https://github.com/dpshai/xax
3
+ Version: 0.1.0
4
+ Summary: A library for fast Jax experimentation
5
+ Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
7
7
  Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
+ Requires-Dist: attrs
10
11
  Requires-Dist: jax
11
12
  Requires-Dist: jaxtyping
12
13
  Requires-Dist: equinox
@@ -30,10 +31,28 @@ Requires-Dist: pytest; extra == "dev"
30
31
  Requires-Dist: types-pillow; extra == "dev"
31
32
  Requires-Dist: types-psutil; extra == "dev"
32
33
  Requires-Dist: types-requests; extra == "dev"
34
+ Provides-Extra: export
35
+ Requires-Dist: orbax-export; extra == "export"
36
+ Requires-Dist: tensorflow; extra == "export"
37
+ Provides-Extra: flax
38
+ Requires-Dist: flax; extra == "flax"
39
+ Provides-Extra: all
40
+ Requires-Dist: black; extra == "all"
41
+ Requires-Dist: darglint; extra == "all"
42
+ Requires-Dist: mypy; extra == "all"
43
+ Requires-Dist: ruff; extra == "all"
44
+ Requires-Dist: pytest; extra == "all"
45
+ Requires-Dist: types-pillow; extra == "all"
46
+ Requires-Dist: types-psutil; extra == "all"
47
+ Requires-Dist: types-requests; extra == "all"
48
+ Requires-Dist: orbax-export; extra == "all"
49
+ Requires-Dist: tensorflow; extra == "all"
50
+ Requires-Dist: flax; extra == "all"
33
51
  Dynamic: author
34
52
  Dynamic: description
35
53
  Dynamic: description-content-type
36
54
  Dynamic: home-page
55
+ Dynamic: license-file
37
56
  Dynamic: provides-extra
38
57
  Dynamic: requires-dist
39
58
  Dynamic: requires-python
@@ -1,18 +1,22 @@
1
- xax/__init__.py,sha256=RTUsDh_R0TFa09q-_U0vd-eCYRC-bCaHqHlayp8U2hU,9736
1
+ xax/__init__.py,sha256=psnt49vcnlzZ7llojgBqCEP0ZquHws_8tZpAZ-5vvLE,13280
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
- xax/requirements.txt,sha256=NmU9PNJhfLtNqqtWWf8WqMjgbBPCn_yt8oMGAgS7Fno,291
4
+ xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
5
5
  xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
7
7
  xax/core/state.py,sha256=y123fL7pMgk25TPG6KN0LRIF_eYnD9eP7OfqtoQJGNE,2178
8
8
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
+ xax/nn/equinox.py,sha256=1Ck6ycz76dhit2LHX4y2lp3WJSPsDuRt7TK7AxxQhww,4837
11
+ xax/nn/export.py,sha256=Do2bLjJTD744mxpQuPYpz8fZ3EIjBLaaZfhp8maNVrg,5303
10
12
  xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
13
+ xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
14
+ xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
11
15
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
12
16
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
- xax/task/base.py,sha256=LHDmM2c_Ps5cGEzn_QUpmyInD7zJJm3Yt9eSeij2Vus,7297
14
- xax/task/logger.py,sha256=orN1jmM4SIR2EiYk8bNoJZscmhX1FytADBU6p9qpows,29256
15
- xax/task/script.py,sha256=4LyXrpj0V36TjAZT4lvQeiOTqa5U2tommHKwgWDCE24,1025
17
+ xax/task/base.py,sha256=MlH5dTKAiMzFRI5fmXCvL1k8ELbalWMBICeVxmW6k2U,7479
18
+ xax/task/logger.py,sha256=z5tcFWsoQkDExCZSL7k_ub-a2RfxmGGGn0WOnYBG82Y,32631
19
+ xax/task/script.py,sha256=zt36Sobdoer86gXHqc4sMAW7bqZRVl6IEExuQZH2USk,926
16
20
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
17
21
  xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
22
  xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
@@ -22,31 +26,35 @@ xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
22
26
  xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
23
27
  xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
24
28
  xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
25
- xax/task/loggers/stdout.py,sha256=nxQXkS9JUR38RKsU9qj0dgePKguK0BFa9nl_BdGO8cE,6758
26
- xax/task/loggers/tensorboard.py,sha256=FGW96z77oG0Kf3cO6Zznx5U3kJNzPWcuSkpY4RnbFCo,6909
29
+ xax/task/loggers/stdout.py,sha256=bR0k-PfmFgLfPxLPb4hZw_8G_msA32UeHfAAu11nEYs,6757
30
+ xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
27
31
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
28
- xax/task/mixins/artifacts.py,sha256=1H7ZbR-KSsXhVtqGVlqMi-TXfn1-dM7YnTCLVuw594s,3835
29
- xax/task/mixins/checkpointing.py,sha256=AMlobojybvJdDZcNCxm1DHSVC_2Qvnu_MbRcsc_8eoA,8508
32
+ xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
33
+ xax/task/mixins/checkpointing.py,sha256=sRkVxJbQfqDf1-lp1KFrAGYWHhTlV8_DORxGQ_69P1A,8954
30
34
  xax/task/mixins/compile.py,sha256=FRsxwLnZjjxpeWJ7Bx_d8XUY50oDoGidgpeRt4ejeQk,3377
31
35
  xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
32
36
  xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
33
37
  xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
34
- xax/task/mixins/logger.py,sha256=CIQ4w4K3FcxN6A9xUfITdVkulSxPa4iaTe6cbs9ruaM,1958
38
+ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
35
39
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
36
40
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
37
- xax/task/mixins/step_wrapper.py,sha256=DJw42mUGwgKx2tkeqatKR9_F4J8ug4wmxKMeJPmhcVQ,1560
38
- xax/task/mixins/train.py,sha256=dhGL_IuDaJy39BooYlO7JO-_EotKldtBhBplDGU_AnM,21745
41
+ xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
42
+ xax/task/mixins/train.py,sha256=6IW1gNnE1a92E1h6SThmnGS9dyjqw0KrWzx0SG6n0_Q,22318
39
43
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
- xax/utils/experiments.py,sha256=qT3H0fyVH8DN417x7T0Xmz4SKoogW81-EHcZfyktFI8,28300
41
- xax/utils/jax.py,sha256=VzEVB766UyH3_cgN6UP0FkCsDuGlYg5KJj8YJS4yYUk,439
44
+ xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
45
+ xax/utils/experiments.py,sha256=_cwoBaiBxoQ_Tstm0rz7TEqfELqcktmPflb6AP1K0qA,28779
46
+ xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
47
+ xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
42
48
  xax/utils/logging.py,sha256=ST1hp2C2xntVVJBUHwo3YxPK19fBLNvHU2WGO1xqcXA,6418
43
49
  xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
44
- xax/utils/tensorboard.py,sha256=oGq2E3Yr0z2xaACv2UOVt_CHEVc8fBxI8V1M99Fd34E,9742
50
+ xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
51
+ xax/utils/pytree.py,sha256=7GjQoPc_ZSZt3QS_9qXoBWl1jfMp1qZa7aViQoWJ0OQ,8864
52
+ xax/utils/tensorboard.py,sha256=_S70dS69pduiD05viHAGgYGsaBry1QL2ej6ZwUIXPOE,16170
45
53
  xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
46
54
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
47
55
  xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
48
- xax-0.0.6.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
49
- xax-0.0.6.dist-info/METADATA,sha256=YO2c2PUMWkH1ILfPhFWKK4Sodbo9qUpUOCIkm4aLHfg,1171
50
- xax-0.0.6.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
51
- xax-0.0.6.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
52
- xax-0.0.6.dist-info/RECORD,,
56
+ xax-0.1.0.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
57
+ xax-0.1.0.dist-info/METADATA,sha256=b5q3AVoywNcDoTLcsf5mCk3wSmL4BHDAudewxmY1XJw,1877
58
+ xax-0.1.0.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
59
+ xax-0.1.0.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
60
+ xax-0.1.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.2)
2
+ Generator: setuptools (77.0.3)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5