xax 0.0.7__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +94 -4
- xax/nn/equinox.py +180 -0
- xax/nn/export.py +147 -0
- xax/nn/geom.py +26 -0
- xax/nn/norm.py +23 -0
- xax/requirements.txt +1 -0
- xax/task/base.py +6 -0
- xax/task/logger.py +220 -44
- xax/task/loggers/stdout.py +2 -2
- xax/task/loggers/tensorboard.py +25 -14
- xax/task/mixins/artifacts.py +1 -21
- xax/task/mixins/checkpointing.py +19 -5
- xax/task/mixins/logger.py +28 -4
- xax/task/mixins/step_wrapper.py +23 -32
- xax/task/mixins/train.py +50 -35
- xax/task/script.py +0 -4
- xax/utils/debugging.py +49 -0
- xax/utils/experiments.py +23 -4
- xax/utils/jaxpr.py +77 -0
- xax/utils/logging.py +12 -2
- xax/utils/pytree.py +189 -1
- xax/utils/tensorboard.py +177 -1
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info}/METADATA +23 -4
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info}/RECORD +27 -22
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info}/WHEEL +1 -1
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {xax-0.0.7.dist-info → xax-0.1.1.dist-info}/top_level.txt +0 -0
xax/utils/pytree.py
CHANGED
@@ -4,7 +4,7 @@ import chex
 import jax
 import jax.numpy as jnp
 from jax import Array
-from jaxtyping import PyTree
+from jaxtyping import PRNGKeyArray, PyTree
 
 
 def slice_array(x: Array, start: Array, slice_length: int) -> Array:
@@ -12,6 +12,14 @@ def slice_array(x: Array, start: Array, slice_length: int) -> Array:
 
     For multi-dimensional arrays, this slices only along the first dimension
     and keeps all other dimensions intact.
+
+    Args:
+        x: The array to slice.
+        start: The start index of the slice.
+        slice_length: The length of the slice.
+
+    Returns:
+        The sliced array.
     """
     chex.assert_shape(start, ())
     chex.assert_shape(slice_length, ())
@@ -38,6 +46,21 @@ def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
     return jax.tree_util.tree_map(lambda x: flatten_array(x, flatten_size), pytree)
 
 
+def pytree_has_nans(pytree: PyTree) -> Array:
+    """Check if a pytree has any NaNs."""
+    has_nans = jax.tree_util.tree_reduce(
+        lambda a, b: jnp.logical_or(a, b),
+        jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree),
+    )
+    return has_nans
+
+
+def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
+    """Update a pytree based on a condition."""
+    # Tricky, need use tree_map because where expects array leafs.
+    return jax.tree_util.tree_map(lambda x, y: jnp.where(cond, x, y), new, original)
+
+
 def compute_nan_ratio(pytree: PyTree) -> Array:
     """Computes the ratio of NaNs vs non-NaNs in a given PyTree."""
     nan_counts = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
@@ -48,3 +71,168 @@ def compute_nan_ratio(pytree: PyTree) -> Array:
     overall_nan_ratio = jnp.array(total_nans / total_elements)
 
     return overall_nan_ratio
+
+
+def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
+    """Reshuffle a pytree along the leading dimensions.
+
+    This function reshuffles the data along the leading dimensions specified by batch_shape.
+    Assumes the dimensions to shuffle are the leading ones.
+
+    Args:
+        data: A pytree with arrays.
+        batch_shape: A tuple of integers specifying the size of each leading dimension to shuffle.
+        rng: A JAX PRNG key.
+
+    Returns:
+        A new pytree with the same structure but with the data reshuffled along
+        the leading dimensions.
+    """
+    # Create a permutation for the flattened batch dimensions
+    flat_size = 1
+    for dim in batch_shape:
+        flat_size *= dim
+
+    perm = jax.random.permutation(rng, flat_size)
+
+    def permute_array(x: Array) -> Array:
+        if not isinstance(x, jnp.ndarray):
+            return x
+
+        # Check if the array has enough dimensions
+        if len(x.shape) < len(batch_shape):
+            return x
+
+        # Check if the dimensions match the batch_shape
+        for i, dim in enumerate(batch_shape):
+            if x.shape[i] != dim:
+                return x
+
+        # Reshape to flatten the batch dimensions
+        orig_shape = x.shape
+        reshaped = x.reshape((flat_size,) + orig_shape[len(batch_shape) :])
+
+        # Apply the permutation
+        permuted = reshaped[perm]
+
+        # Reshape back to the original shape
+        return permuted.reshape(orig_shape)
+
+    return jax.tree_util.tree_map(permute_array, data)
+
+
+def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
+    """Reshuffle a rollout array across arbitrary batch dimensions independently of each other."""
+    rngs = jax.random.split(rng, len(batch_shape))
+    perms = [jax.random.permutation(rng_i, dim) for rng_i, dim in zip(rngs, batch_shape)]
+    # n-dimensional index grid from permutations
+    idx_grids = jnp.meshgrid(*perms, indexing="ij")
+
+    def permute_array(x: Array) -> Array:
+        if isinstance(x, Array):
+            return x[tuple(idx_grids)]
+        return x
+
+    return jax.tree_util.tree_map(permute_array, data)
+
+
+TransposeResult = tuple[PyTree, tuple[int, ...], tuple[int, ...]]
+PathType = tuple[str | int, ...]
+
+
+def reshuffle_pytree_along_dims(
+    data: PyTree,
+    dims: tuple[int, ...],
+    shape_dims: tuple[int, ...],
+    rng: PRNGKeyArray,
+) -> PyTree:
+    """Reshuffle a pytree along arbitrary dimensions.
+
+    Allows reshuffling along any dimensions, not just the leading ones.
+    It transposes the data to make the specified dimensions the leading ones,
+    then reshuffles, and finally transposes back.
+
+    Args:
+        data: A pytree with arrays.
+        dims: A tuple of integers specifying which dimensions to shuffle along.
+            For example, (1,) would shuffle along the second dimension.
+        shape_dims: A tuple of integers specifying the size of each dimension to shuffle.
+            Must have the same length as dims.
+        rng: A JAX PRNG key.
+
+    Returns:
+        A new pytree with the same structure but with the data reshuffled along
+        the specified dimensions.
+    """
+    if len(dims) != len(shape_dims):
+        raise ValueError(f"dims {dims} and shape_dims {shape_dims} must have the same length")
+
+    def transpose_for_shuffle(x: PyTree) -> TransposeResult:
+        if not isinstance(x, jnp.ndarray):
+            return x, (), ()
+
+        # Check if the array has enough dimensions
+        if len(x.shape) <= max(dims):
+            return x, (), ()
+
+        # Check if the dimensions match the shape_dims
+        for i, dim in enumerate(dims):
+            if x.shape[dim] != shape_dims[i]:
+                raise ValueError(f"Array shape {x.shape} doesn't match shape_dims {shape_dims} at dimension {dim}")
+
+        # Create the transpose order to move the specified dimensions to the front
+        # while preserving the relative order of the other dimensions
+        n_dims = len(x.shape)
+        other_dims = [i for i in range(n_dims) if i not in dims]
+        transpose_order = tuple(dims) + tuple(other_dims)
+
+        # Transpose the array
+        transposed = jnp.transpose(x, transpose_order)
+
+        return transposed, transpose_order, x.shape
+
+    def transpose_back(x: PyTree, transpose_order: tuple[int, ...], original_shape: tuple[int, ...]) -> PyTree:
+        if not isinstance(x, jnp.ndarray) or not transpose_order:
+            return x
+
+        # Create the inverse transpose order
+        inverse_order = [0] * len(transpose_order)
+        for i, j in enumerate(transpose_order):
+            inverse_order[j] = i
+
+        # Transpose back
+        return jnp.transpose(x, inverse_order)
+
+    # First, transpose all arrays to make the specified dimensions the leading ones
+    transposed_data: dict[PathType, Array] = {}
+    transpose_info: dict[PathType, tuple[tuple[int, ...], tuple[int, ...]]] = {}
+
+    def prepare_for_shuffle(path: PathType, x: PyTree) -> PyTree:
+        if isinstance(x, jnp.ndarray):
+            transposed, transpose_order, original_shape = transpose_for_shuffle(x)
+            if isinstance(transposed, jnp.ndarray):  # Check if it's an array
+                transposed_data[path] = transposed
+                transpose_info[path] = (transpose_order, original_shape)
+        return x
+
+    jax.tree_util.tree_map_with_path(prepare_for_shuffle, data)
+
+    # Create a transposed pytree
+    def get_transposed(path: PathType, x: PyTree) -> PyTree:
+        if isinstance(x, jnp.ndarray) and path in transposed_data:
+            return transposed_data[path]
+        return x
+
+    transposed_pytree = jax.tree_util.tree_map_with_path(get_transposed, data)
+
+    # Reshuffle the transposed pytree along the leading dimensions
+    reshuffled_transposed = reshuffle_pytree(transposed_pytree, shape_dims, rng)
+
+    # Transpose back
+    def restore_transpose(path: PathType, x: PyTree) -> PyTree:
+        if isinstance(x, jnp.ndarray) and path in transpose_info:
+            transpose_order, original_shape = transpose_info[path]
+            return transpose_back(x, transpose_order, original_shape)
+        return x
+
+    return jax.tree_util.tree_map_with_path(restore_transpose, reshuffled_transposed)
xax/utils/tensorboard.py
CHANGED
@@ -14,7 +14,7 @@ from PIL.Image import Image as PILImage
 from tensorboard.compat.proto.config_pb2 import RunMetadata
 from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
 from tensorboard.compat.proto.graph_pb2 import GraphDef
-from tensorboard.compat.proto.summary_pb2 import Summary, SummaryMetadata
+from tensorboard.compat.proto.summary_pb2 import HistogramProto, Summary, SummaryMetadata
 from tensorboard.compat.proto.tensor_pb2 import TensorProto
 from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
 from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
@@ -25,6 +25,65 @@ from xax.core.state import Phase
 ImageShape = Literal["HWC", "CHW", "HW", "NHWC", "NCHW", "NHW"]
 
 
+def make_histogram(values: np.ndarray, bins: str | np.ndarray, max_bins: int | None = None) -> HistogramProto:
+    """Convert values into a histogram proto using logic from histogram.cc.
+
+    Args:
+        values: Input values to create histogram from
+        bins: Bin specification (string like 'auto' or array of bin edges)
+        max_bins: Optional maximum number of bins
+
+    Returns:
+        HistogramProto containing the histogram data
+    """
+    if values.size == 0:
+        raise ValueError("The input has no element.")
+    values = values.reshape(-1)
+    counts, limits = np.histogram(values, bins=bins)
+    num_bins = len(counts)
+
+    if max_bins is not None and num_bins > max_bins:
+        subsampling = num_bins // max_bins
+        subsampling_remainder = num_bins % subsampling
+        if subsampling_remainder != 0:
+            counts = np.pad(
+                counts,
+                pad_width=[[0, subsampling - subsampling_remainder]],
+                mode="constant",
+                constant_values=0,
+            )
+        counts = counts.reshape(-1, subsampling).sum(axis=-1)
+        new_limits = np.empty((counts.size + 1,), limits.dtype)
+        new_limits[:-1] = limits[:-1:subsampling]
+        new_limits[-1] = limits[-1]
+        limits = new_limits
+
+    # Find the first and the last bin defining the support of the histogram:
+    cum_counts = np.cumsum(np.greater(counts, 0))
+    start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
+    start = int(start)
+    end = int(end) + 1
+
+    # TensorBoard only includes the right bin limits. To still have the leftmost limit
+    # included, we include an empty bin left.
+    counts = counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
+    limits = limits[start : end + 1]
+
+    if counts.size == 0 or limits.size == 0:
+        raise ValueError("The histogram is empty, please file a bug report.")
+
+    sum_sq = values.dot(values)
+    return HistogramProto(
+        min=values.min(),
+        max=values.max(),
+        num=len(values),
+        sum=values.sum(),
+        sum_squares=sum_sq,
+        bucket_limit=limits.tolist(),
+        bucket=counts.tolist(),
+    )
+
+
 class TensorboardProtobufWriter:
     def __init__(
         self,
@@ -263,6 +322,123 @@ class TensorboardWriter:
             walltime=walltime,
         )
 
+    def add_histogram(
+        self,
+        tag: str,
+        values: np.ndarray,
+        global_step: int | None = None,
+        bins: str | np.ndarray = "auto",
+        walltime: float | None = None,
+        max_bins: int | None = None,
+    ) -> None:
+        hist = make_histogram(values.astype(float), bins, max_bins)
+        self.pb_writer.add_summary(
+            Summary(value=[Summary.Value(tag=tag, histo=hist)]),
+            global_step=global_step,
+            walltime=walltime,
+        )
+
+    def add_histogram_raw(
+        self,
+        tag: str,
+        min: float | np.ndarray,
+        max: float | np.ndarray,
+        num: int | np.ndarray,
+        sum: float | np.ndarray,
+        sum_squares: float | np.ndarray,
+        bucket_limits: list[float | np.ndarray],
+        bucket_counts: list[int | np.ndarray],
+        global_step: int | None = None,
+        walltime: float | None = None,
+    ) -> None:
+        """Add histogram with raw data to summary.
+
+        Args:
+            tag: Data identifier
+            min: Min value
+            max: Max value
+            num: Number of values
+            sum: Sum of all values
+            sum_squares: Sum of squares for all values
+            bucket_limits: Upper value per bucket
+            bucket_counts: Number of values per bucket
+            global_step: Global step value to record
+            walltime: Optional override default walltime
+        """
+        if len(bucket_limits) != len(bucket_counts):
+            raise ValueError("len(bucket_limits) != len(bucket_counts)")
+
+        # Convert numpy arrays to Python types
+        hist = HistogramProto(
+            min=float(min),
+            max=float(max),
+            num=int(num),
+            sum=float(sum),
+            sum_squares=float(sum_squares),
+            bucket_limit=[float(x) for x in bucket_limits],
+            bucket=[int(x) for x in bucket_counts],
+        )
+        self.pb_writer.add_summary(
+            Summary(value=[Summary.Value(tag=tag, histo=hist)]),
+            global_step=global_step,
+            walltime=walltime,
+        )
+
+    def add_gaussian_distribution(
+        self,
+        tag: str,
+        mean: float | np.ndarray,
+        std: float | np.ndarray,
+        bins: int = 16,
+        global_step: int | None = None,
+        walltime: float | None = None,
+    ) -> None:
+        """Add a Gaussian distribution to the summary.
+
+        Args:
+            tag: Data identifier
+            mean: Mean of the Gaussian distribution
+            std: Standard deviation of the Gaussian distribution
+            bins: Number of bins to use for the histogram
+            global_step: Global step value to record
+            walltime: Optional override default walltime
+        """
+        # Convert numpy arrays to Python types
+        mean = float(mean)
+        std = float(std)
+
+        # Create bin edges spanning ±3 standard deviations
+        bin_edges = np.linspace(mean - 3 * std, mean + 3 * std, bins + 1)
+
+        # Calculate the probability density for each bin
+        bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
+        gaussian_pdf = np.exp(-0.5 * ((bin_centers - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))
+
+        # Scale the PDF to represent counts
+        num_samples = bins * 1000
+        bucket_counts = (gaussian_pdf * num_samples * (bin_edges[1] - bin_edges[0])).astype(int)
+
+        # Ensure we have at least one count per bin for visualization
+        bucket_counts = np.maximum(bucket_counts, 1)
+
+        # Calculate actual statistics based on the discretized distribution
+        total_counts = float(bucket_counts.sum())
+        weighted_sum = float((bin_centers * bucket_counts).sum())
+        weighted_sum_squares = float((bin_centers**2 * bucket_counts).sum())
+
+        self.add_histogram_raw(
+            tag=tag,
+            min=float(bin_edges[0]),
+            max=float(bin_edges[-1]),
+            num=int(total_counts),
+            sum=weighted_sum,
+            sum_squares=weighted_sum_squares,
+            bucket_limits=bin_edges[1:].tolist(),  # TensorBoard expects right bin edges
+            bucket_counts=bucket_counts.tolist(),
+            global_step=global_step,
+            walltime=walltime,
+        )
+
 
 class TensorboardWriterKwargs(TypedDict):
     max_queue_size: int
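The new TensorboardWriter methods route histograms through the same protobuf writer as the existing summaries: add_histogram builds a HistogramProto via make_histogram, add_histogram_raw accepts precomputed buckets, and add_gaussian_distribution synthesizes buckets from a mean and standard deviation. A minimal sketch of how they might be called; the writer construction below is an assumption (the constructor is not shown in this diff), and the tag names and data are illustrative:

```python
import numpy as np

from xax.utils.tensorboard import TensorboardWriter, make_histogram

# Assumed constructor argument: a log directory (not part of this diff).
writer = TensorboardWriter("runs/example")

values = np.random.default_rng(0).normal(loc=0.0, scale=1.0, size=10_000)

# Build a HistogramProto directly, e.g. to inspect bucket counts...
proto = make_histogram(values.astype(float), bins="auto", max_bins=64)
print(proto.num, len(proto.bucket))

# ...or log through the writer, which calls make_histogram internally.
writer.add_histogram("weights/layer0", values, global_step=100)

# Log an idealized Gaussian with the same mean/std for comparison.
writer.add_gaussian_distribution(
    "weights/layer0_fit",
    mean=float(values.mean()),
    std=float(values.std()),
    global_step=100,
)
```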
{xax-0.0.7.dist-info → xax-0.1.1.dist-info}/METADATA
CHANGED
@@ -1,12 +1,13 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: xax
-Version: 0.
-Summary:
-Home-page: https://github.com/
+Version: 0.1.1
+Summary: A library for fast Jax experimentation
+Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
+Requires-Dist: attrs
 Requires-Dist: jax
 Requires-Dist: jaxtyping
 Requires-Dist: equinox
@@ -30,10 +31,28 @@ Requires-Dist: pytest; extra == "dev"
 Requires-Dist: types-pillow; extra == "dev"
 Requires-Dist: types-psutil; extra == "dev"
 Requires-Dist: types-requests; extra == "dev"
+Provides-Extra: export
+Requires-Dist: orbax-export; extra == "export"
+Requires-Dist: tensorflow; extra == "export"
+Provides-Extra: flax
+Requires-Dist: flax; extra == "flax"
+Provides-Extra: all
+Requires-Dist: black; extra == "all"
+Requires-Dist: darglint; extra == "all"
+Requires-Dist: mypy; extra == "all"
+Requires-Dist: ruff; extra == "all"
+Requires-Dist: pytest; extra == "all"
+Requires-Dist: types-pillow; extra == "all"
+Requires-Dist: types-psutil; extra == "all"
+Requires-Dist: types-requests; extra == "all"
+Requires-Dist: orbax-export; extra == "all"
+Requires-Dist: tensorflow; extra == "all"
+Requires-Dist: flax; extra == "all"
 Dynamic: author
 Dynamic: description
 Dynamic: description-content-type
 Dynamic: home-page
+Dynamic: license-file
 Dynamic: provides-extra
 Dynamic: requires-dist
 Dynamic: requires-python
{xax-0.0.7.dist-info → xax-0.1.1.dist-info}/RECORD
CHANGED
@@ -1,19 +1,22 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=JyKRACir9b0bkuG93bwxADFrVr-Lo76kenDBJtvb_wQ,13280
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
-xax/requirements.txt,sha256=
+xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
 xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
 xax/core/state.py,sha256=y123fL7pMgk25TPG6KN0LRIF_eYnD9eP7OfqtoQJGNE,2178
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
+xax/nn/equinox.py,sha256=1Ck6ycz76dhit2LHX4y2lp3WJSPsDuRt7TK7AxxQhww,4837
+xax/nn/export.py,sha256=Do2bLjJTD744mxpQuPYpz8fZ3EIjBLaaZfhp8maNVrg,5303
 xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
-xax/nn/geom.py,sha256=
+xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
+xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xax/task/base.py,sha256=
-xax/task/logger.py,sha256=
-xax/task/script.py,sha256=
+xax/task/base.py,sha256=MlH5dTKAiMzFRI5fmXCvL1k8ELbalWMBICeVxmW6k2U,7479
+xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
+xax/task/script.py,sha256=zt36Sobdoer86gXHqc4sMAW7bqZRVl6IEExuQZH2USk,926
 xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
 xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
@@ -23,33 +26,35 @@ xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
 xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
 xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
-xax/task/loggers/stdout.py,sha256=
-xax/task/loggers/tensorboard.py,sha256=
+xax/task/loggers/stdout.py,sha256=bR0k-PfmFgLfPxLPb4hZw_8G_msA32UeHfAAu11nEYs,6757
+xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
 xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
-xax/task/mixins/artifacts.py,sha256=
-xax/task/mixins/checkpointing.py,sha256=
+xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
+xax/task/mixins/checkpointing.py,sha256=sRkVxJbQfqDf1-lp1KFrAGYWHhTlV8_DORxGQ_69P1A,8954
 xax/task/mixins/compile.py,sha256=FRsxwLnZjjxpeWJ7Bx_d8XUY50oDoGidgpeRt4ejeQk,3377
 xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
 xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
 xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
-xax/task/mixins/logger.py,sha256=
+xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
-xax/task/mixins/step_wrapper.py,sha256
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
+xax/task/mixins/train.py,sha256=BEC7HSwBlGZDe7jCsedqEA8-K1Zx52-bTjsBONYIE5g,22225
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xax/utils/
+xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
+xax/utils/experiments.py,sha256=_cwoBaiBxoQ_Tstm0rz7TEqfELqcktmPflb6AP1K0qA,28779
 xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
-xax/utils/
+xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
+xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
 xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
-xax/utils/pytree.py,sha256=
-xax/utils/tensorboard.py,sha256=
+xax/utils/pytree.py,sha256=7GjQoPc_ZSZt3QS_9qXoBWl1jfMp1qZa7aViQoWJ0OQ,8864
+xax/utils/tensorboard.py,sha256=_S70dS69pduiD05viHAGgYGsaBry1QL2ej6ZwUIXPOE,16170
 xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
 xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
-xax-0.
-xax-0.
-xax-0.
-xax-0.
-xax-0.
+xax-0.1.1.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.1.1.dist-info/METADATA,sha256=tJ4ilL3uBbykHBQTHbh-bN6m4hrHqivyyFeuI33ddX4,1877
+xax-0.1.1.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
+xax-0.1.1.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.1.1.dist-info/RECORD,,
{xax-0.0.7.dist-info → xax-0.1.1.dist-info/licenses}/LICENSE
File without changes
{xax-0.0.7.dist-info → xax-0.1.1.dist-info}/top_level.txt
File without changes