xax 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +102 -2
- xax/core/conf.py +8 -33
- xax/core/state.py +13 -23
- xax/nn/geom.py +75 -0
- xax/requirements.txt +2 -0
- xax/task/base.py +2 -0
- xax/task/logger.py +194 -122
- xax/task/loggers/callback.py +4 -16
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +14 -28
- xax/task/mixins/__init__.py +1 -0
- xax/task/mixins/artifacts.py +7 -4
- xax/task/mixins/checkpointing.py +12 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +16 -5
- xax/task/mixins/data_loader.py +23 -12
- xax/task/mixins/gpu_stats.py +19 -5
- xax/task/mixins/logger.py +4 -2
- xax/task/mixins/process.py +4 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +189 -129
- xax/task/script.py +1 -1
- xax/task/task.py +7 -0
- xax/utils/jax.py +126 -0
- xax/utils/profile.py +61 -0
- xax/utils/pytree.py +50 -0
- xax/utils/tensorboard.py +48 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/METADATA +12 -2
- xax-0.0.7.dist-info/RECORD +55 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.5.dist-info/RECORD +0 -52
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/LICENSE +0 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/top_level.txt +0 -0
xax/utils/profile.py
ADDED
@@ -0,0 +1,61 @@
|
|
1
|
+
"""Profiling utilities."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import time
|
6
|
+
from functools import wraps
|
7
|
+
from typing import Callable, ParamSpec, TypeVar
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)
|
10
|
+
|
11
|
+
P = ParamSpec("P") # For function parameters
|
12
|
+
R = TypeVar("R") # For function return type
|
13
|
+
|
14
|
+
|
15
|
+
def profile(fn: Callable[P, R]) -> Callable[P, R]:
|
16
|
+
"""Profiling decorator that tracks function call count and execution time.
|
17
|
+
|
18
|
+
Activated when the PROFILE environment variable is set to "1".
|
19
|
+
|
20
|
+
Returns:
|
21
|
+
A decorated function with profiling capabilities.
|
22
|
+
"""
|
23
|
+
|
24
|
+
class ProfileState:
|
25
|
+
call_count = 0
|
26
|
+
total_time = 0.0
|
27
|
+
|
28
|
+
@wraps(fn)
|
29
|
+
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
|
30
|
+
if os.environ.get("PROFILE", "0") != "1":
|
31
|
+
return fn(*args, **kwargs)
|
32
|
+
|
33
|
+
start_time = time.time()
|
34
|
+
res = fn(*args, **kwargs)
|
35
|
+
end_time = time.time()
|
36
|
+
runtime = end_time - start_time
|
37
|
+
|
38
|
+
ProfileState.call_count += 1
|
39
|
+
ProfileState.total_time += runtime
|
40
|
+
|
41
|
+
# Handle class methods by showing class name
|
42
|
+
if fn.__name__ == "__call__" or (args and hasattr(args[0], "__class__")):
|
43
|
+
try:
|
44
|
+
class_name = args[0].__class__.__name__ + "."
|
45
|
+
except (IndexError, AttributeError):
|
46
|
+
class_name = ""
|
47
|
+
else:
|
48
|
+
class_name = ""
|
49
|
+
|
50
|
+
logger.info(
|
51
|
+
"%s %s - call #%s, took %s seconds, total: %s seconds",
|
52
|
+
class_name,
|
53
|
+
fn.__name__,
|
54
|
+
ProfileState.call_count,
|
55
|
+
runtime,
|
56
|
+
ProfileState.total_time,
|
57
|
+
)
|
58
|
+
|
59
|
+
return res
|
60
|
+
|
61
|
+
return wrapped
|
xax/utils/pytree.py
ADDED
@@ -0,0 +1,50 @@
|
|
1
|
+
"""Utils for accessing, modifying, and otherwise manipulating pytrees."""
|
2
|
+
|
3
|
+
import chex
|
4
|
+
import jax
|
5
|
+
import jax.numpy as jnp
|
6
|
+
from jax import Array
|
7
|
+
from jaxtyping import PyTree
|
8
|
+
|
9
|
+
|
10
|
+
def slice_array(x: Array, start: Array, slice_length: int) -> Array:
    """Take ``slice_length`` consecutive entries of ``x`` along its leading axis.

    Only axis 0 is sliced; every trailing axis is kept whole.

    Args:
        x: Array to slice along axis 0.
        start: Scalar (shape ``()``) index of the first entry to take.
        slice_length: Number of entries to take along axis 0.

    Returns:
        An array of shape ``(slice_length, *x.shape[1:])`` produced by
        ``jax.lax.dynamic_slice``.
    """
    chex.assert_shape(start, ())
    chex.assert_shape(slice_length, ())

    # Offset along axis 0 is `start`; every trailing axis starts at 0.
    trailing_dims = len(x.shape) - 1
    offsets = (start, *([0] * trailing_dims))
    sizes = (slice_length, *x.shape[1:])

    return jax.lax.dynamic_slice(x, offsets, sizes)
|
22
|
+
|
23
|
+
|
24
|
+
def slice_pytree(pytree: PyTree, start: Array, slice_length: int) -> PyTree:
    """Slice every leaf of ``pytree`` along its leading axis.

    Each leaf goes through ``slice_array`` with the same ``start`` and
    ``slice_length``; the tree structure is preserved.
    """

    def _slice_leaf(leaf: Array) -> Array:
        return slice_array(leaf, start, slice_length)

    return jax.tree_util.tree_map(_slice_leaf, pytree)
|
27
|
+
|
28
|
+
|
29
|
+
def flatten_array(x: Array, flatten_size: int) -> Array:
    """Reshape ``x`` to ``(flatten_size, *x.shape[2:])``.

    When ``x`` has at least two dimensions this merges its first two axes into
    one, so ``flatten_size`` must equal ``x.shape[0] * x.shape[1]``;
    ``jnp.reshape`` raises if the element count does not match. The trailing
    assertion additionally rejects ``-1`` (a size for reshape to infer), since
    the resulting leading size would then differ from ``flatten_size``.
    """
    target_shape = (flatten_size,) + tuple(x.shape[2:])
    merged = jnp.reshape(x, target_shape)
    assert merged.shape[0] == flatten_size
    return merged
|
34
|
+
|
35
|
+
|
36
|
+
def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
    """Reshape every leaf of ``pytree`` to ``(flatten_size, *leaf.shape[2:])``.

    Each leaf goes through ``flatten_array``; the tree structure is preserved.
    """

    def _flatten_leaf(leaf: Array) -> Array:
        return flatten_array(leaf, flatten_size)

    return jax.tree_util.tree_map(_flatten_leaf, pytree)
|
39
|
+
|
40
|
+
|
41
|
+
def compute_nan_ratio(pytree: PyTree) -> Array:
    """Compute the fraction of NaN elements across all leaves of a pytree.

    The result is ``(number of NaN elements) / (total number of elements)``,
    a value in ``[0, 1]``. It is not the ratio of NaNs to non-NaNs.

    Args:
        pytree: A pytree whose leaves are arrays or numeric scalars.

    Returns:
        A scalar float array holding the NaN fraction. A pytree with no
        elements (no leaves, or only zero-size leaves) yields ``0.0`` instead
        of dividing by zero.
    """
    leaves = jax.tree_util.tree_leaves(pytree)

    # jnp.size also works for Python scalar leaves, which have no `.size`.
    total_elements = sum(jnp.size(leaf) for leaf in leaves)
    if total_elements == 0:
        # Nothing to count means no NaNs; avoids 0 / 0.
        return jnp.array(0.0)

    total_nans = sum(jnp.sum(jnp.isnan(leaf)) for leaf in leaves)
    return jnp.array(total_nans / total_elements)
|
xax/utils/tensorboard.py
CHANGED
@@ -2,10 +2,14 @@
|
|
2
2
|
|
3
3
|
import functools
|
4
4
|
import io
|
5
|
+
import os
|
6
|
+
import tempfile
|
5
7
|
import time
|
6
8
|
from pathlib import Path
|
7
9
|
from typing import Literal, TypedDict
|
8
10
|
|
11
|
+
import numpy as np
|
12
|
+
import PIL.Image
|
9
13
|
from PIL.Image import Image as PILImage
|
10
14
|
from tensorboard.compat.proto.config_pb2 import RunMetadata
|
11
15
|
from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
|
@@ -186,6 +190,50 @@ class TensorboardWriter:
|
|
186
190
|
walltime=walltime,
|
187
191
|
)
|
188
192
|
|
193
|
+
def add_video(
    self,
    tag: str,
    value: np.ndarray,
    global_step: int | None = None,
    walltime: float | None = None,
    fps: int = 30,
) -> None:
    """Log a video clip as an animated GIF inside a TensorBoard image summary.

    The frames are encoded in memory (no temporary file) and written through
    ``self.pb_writer`` as a ``Summary.Image`` whose payload is the GIF bytes.

    Args:
        tag: Summary tag to log the video under.
        value: Frames as a ``(T, H, W, C)`` array. Presumably ``uint8`` pixel
            values, since each frame is passed to ``PIL.Image.fromarray`` —
            TODO confirm accepted dtypes.
        global_step: Optional step to associate with the summary.
        walltime: Optional wall-clock time to associate with the summary.
        fps: Playback rate; each frame is shown for ``int(1000 / fps)`` ms.

    Raises:
        AssertionError: If ``value`` is not 4-D or contains no frames.
    """
    assert value.ndim == 4, "Video must be 4D array (T, H, W, C)"
    assert value.shape[0] > 0, "Video must contain at least one frame"
    num_channels = value.shape[3]

    # PIL cannot build an image from an (H, W, 1) array, so drop the singleton
    # channel axis and let grayscale frames be (H, W).
    frames = value[..., 0] if num_channels == 1 else value
    images = [PIL.Image.fromarray(frame) for frame in frames]

    # Encode into an in-memory buffer rather than a temporary file: nothing to
    # clean up, and no open file handle is left behind.
    buffer = io.BytesIO()
    images[0].save(
        buffer,
        format="GIF",
        save_all=True,
        append_images=images[1:],
        duration=int(1000 / fps),
        loop=0,
    )
    video_string = buffer.getvalue()

    # Add to summary
    self.pb_writer.add_summary(
        Summary(
            value=[
                Summary.Value(
                    tag=tag,
                    image=Summary.Image(
                        height=value.shape[1],
                        width=value.shape[2],
                        colorspace=num_channels,
                        encoded_image_string=video_string,
                    ),
                ),
            ],
        ),
        global_step=global_step,
        walltime=walltime,
    )
|
236
|
+
|
189
237
|
def add_text(
|
190
238
|
self,
|
191
239
|
tag: str,
|
@@ -1,6 +1,6 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.2
|
2
2
|
Name: xax
|
3
|
-
Version: 0.0.5
|
3
|
+
Version: 0.0.7
|
4
4
|
Summary: The xax project
|
5
5
|
Home-page: https://github.com/dpshai/xax
|
6
6
|
Author: Benjamin Bolte
|
@@ -12,6 +12,8 @@ Requires-Dist: jaxtyping
|
|
12
12
|
Requires-Dist: equinox
|
13
13
|
Requires-Dist: optax
|
14
14
|
Requires-Dist: dpshdl
|
15
|
+
Requires-Dist: chex
|
16
|
+
Requires-Dist: importlib-resources
|
15
17
|
Requires-Dist: cloudpickle
|
16
18
|
Requires-Dist: pillow
|
17
19
|
Requires-Dist: omegaconf
|
@@ -28,6 +30,14 @@ Requires-Dist: pytest; extra == "dev"
|
|
28
30
|
Requires-Dist: types-pillow; extra == "dev"
|
29
31
|
Requires-Dist: types-psutil; extra == "dev"
|
30
32
|
Requires-Dist: types-requests; extra == "dev"
|
33
|
+
Dynamic: author
|
34
|
+
Dynamic: description
|
35
|
+
Dynamic: description-content-type
|
36
|
+
Dynamic: home-page
|
37
|
+
Dynamic: provides-extra
|
38
|
+
Dynamic: requires-dist
|
39
|
+
Dynamic: requires-python
|
40
|
+
Dynamic: summary
|
31
41
|
|
32
42
|
# xax
|
33
43
|
|
@@ -0,0 +1,55 @@
|
|
1
|
+
xax/__init__.py,sha256=ScTkvKaxgpuKhhs9RINJa2XWCj899ndSYrB3FtScfxw,10509
|
2
|
+
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
|
+
xax/requirements.txt,sha256=NmU9PNJhfLtNqqtWWf8WqMjgbBPCn_yt8oMGAgS7Fno,291
|
5
|
+
xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
+
xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
|
7
|
+
xax/core/state.py,sha256=y123fL7pMgk25TPG6KN0LRIF_eYnD9eP7OfqtoQJGNE,2178
|
8
|
+
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
+
xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
10
|
+
xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
|
11
|
+
xax/nn/geom.py,sha256=MtVar9AdqrJQGIFxcIFHyFnV_fblf9Pc4kQT_gTQASI,2195
|
12
|
+
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
13
|
+
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
+
xax/task/base.py,sha256=LHDmM2c_Ps5cGEzn_QUpmyInD7zJJm3Yt9eSeij2Vus,7297
|
15
|
+
xax/task/logger.py,sha256=orN1jmM4SIR2EiYk8bNoJZscmhX1FytADBU6p9qpows,29256
|
16
|
+
xax/task/script.py,sha256=4LyXrpj0V36TjAZT4lvQeiOTqa5U2tommHKwgWDCE24,1025
|
17
|
+
xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
|
18
|
+
xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
+
xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
|
20
|
+
xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
|
21
|
+
xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
|
22
|
+
xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
|
+
xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
|
24
|
+
xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
|
25
|
+
xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
|
26
|
+
xax/task/loggers/stdout.py,sha256=nxQXkS9JUR38RKsU9qj0dgePKguK0BFa9nl_BdGO8cE,6758
|
27
|
+
xax/task/loggers/tensorboard.py,sha256=FGW96z77oG0Kf3cO6Zznx5U3kJNzPWcuSkpY4RnbFCo,6909
|
28
|
+
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
29
|
+
xax/task/mixins/artifacts.py,sha256=1H7ZbR-KSsXhVtqGVlqMi-TXfn1-dM7YnTCLVuw594s,3835
|
30
|
+
xax/task/mixins/checkpointing.py,sha256=AMlobojybvJdDZcNCxm1DHSVC_2Qvnu_MbRcsc_8eoA,8508
|
31
|
+
xax/task/mixins/compile.py,sha256=FRsxwLnZjjxpeWJ7Bx_d8XUY50oDoGidgpeRt4ejeQk,3377
|
32
|
+
xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
|
33
|
+
xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
|
34
|
+
xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
|
35
|
+
xax/task/mixins/logger.py,sha256=CIQ4w4K3FcxN6A9xUfITdVkulSxPa4iaTe6cbs9ruaM,1958
|
36
|
+
xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
|
37
|
+
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
38
|
+
xax/task/mixins/step_wrapper.py,sha256=DJw42mUGwgKx2tkeqatKR9_F4J8ug4wmxKMeJPmhcVQ,1560
|
39
|
+
xax/task/mixins/train.py,sha256=dhGL_IuDaJy39BooYlO7JO-_EotKldtBhBplDGU_AnM,21745
|
40
|
+
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
41
|
+
xax/utils/experiments.py,sha256=qT3H0fyVH8DN417x7T0Xmz4SKoogW81-EHcZfyktFI8,28300
|
42
|
+
xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
|
43
|
+
xax/utils/logging.py,sha256=ST1hp2C2xntVVJBUHwo3YxPK19fBLNvHU2WGO1xqcXA,6418
|
44
|
+
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
45
|
+
xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
|
46
|
+
xax/utils/pytree.py,sha256=Jwx6ErJfv1r2D23D4eKz1Hoo3mAJ0SEqC3EagZarWkw,1858
|
47
|
+
xax/utils/tensorboard.py,sha256=oGq2E3Yr0z2xaACv2UOVt_CHEVc8fBxI8V1M99Fd34E,9742
|
48
|
+
xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
|
49
|
+
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
50
|
+
xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
|
51
|
+
xax-0.0.7.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
52
|
+
xax-0.0.7.dist-info/METADATA,sha256=hE0KO4kYcN6Ed8iZ4649R5ENOUaQysBMW9vTh-94d4I,1171
|
53
|
+
xax-0.0.7.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
|
54
|
+
xax-0.0.7.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
55
|
+
xax-0.0.7.dist-info/RECORD,,
|
xax/task/launchers/staged.py
DELETED
@@ -1,29 +0,0 @@
|
|
1
|
-
"""Defines a base class with utility functions for staged training runs."""
|
2
|
-
|
3
|
-
from abc import ABC
|
4
|
-
from pathlib import Path
|
5
|
-
|
6
|
-
from xax.task.launchers.base import BaseLauncher
|
7
|
-
from xax.task.mixins.artifacts import ArtifactsMixin, Config
|
8
|
-
|
9
|
-
|
10
|
-
class StagedLauncher(BaseLauncher, ABC):
    """Launcher base class with utility functions for staged training runs.

    A task's config is written to a file inside its experiment directory, and
    a task can later be rebuilt from a task key plus that config file.
    """

    def __init__(self, config_file_name: str = "config.yaml") -> None:
        super().__init__()

        # File name of the config written into each task's experiment directory.
        self.config_file_name = config_file_name

    def get_config_path(self, task: "ArtifactsMixin[Config]", use_cli: bool | list[str] = True) -> Path:
        """Write ``task``'s config into its experiment directory and return the file path.

        Before serializing, ``task.config.exp_dir`` is set to the string form of
        ``task.exp_dir``. The config text comes from ``task.config_str`` and is
        written as UTF-8.
        """
        target = task.exp_dir / self.config_file_name
        task.config.exp_dir = str(task.exp_dir)
        with target.open("w", encoding="utf-8") as config_file:
            config_file.write(task.config_str(task.config, use_cli=use_cli))
        return target

    @classmethod
    def from_components(cls, task_key: str, config_path: Path, use_cli: bool | list[str] = True) -> "ArtifactsMixin":
        """Rebuild a task from ``task_key`` and a previously written config file.

        The task's experiment directory is set to the directory containing
        ``config_path``.
        """
        resolved = ArtifactsMixin.from_task_key(task_key)
        loaded = resolved.get_task(config_path, use_cli=use_cli)
        return loaded.set_exp_dir(config_path.parent)
|
xax-0.0.5.dist-info/RECORD
DELETED
@@ -1,52 +0,0 @@
|
|
1
|
-
xax/__init__.py,sha256=3OQTnHGYgaux3i9gTYZxfK8F2zS_hK2QqD-G-Z1TfHQ,7623
|
2
|
-
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
|
-
xax/requirements.txt,sha256=DRn2B9d3mAr57-U3IOIrKm2nYz8H3cYgDy6EIC3SsuE,266
|
5
|
-
xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
-
xax/core/conf.py,sha256=hwgc5sJw0YRSegQLLrmIDtscev-H_a2ST1-V6BJ5aec,5915
|
7
|
-
xax/core/state.py,sha256=7lnVSytuhwPfcobPGdjfQ0QxbLgzWQNipKwXchd58QI,2695
|
8
|
-
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
-
xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
10
|
-
xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
|
11
|
-
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
12
|
-
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
-
xax/task/base.py,sha256=n82Sw-kMLr-WZzh0c_vAAQ2b-DHRYs0U8biPRonBxKU,7252
|
14
|
-
xax/task/logger.py,sha256=MAFIgd6yO0pD3gJHfKTwUDcwaM8DZD3AZtFLvrQtlFo,26740
|
15
|
-
xax/task/script.py,sha256=oBGnScYa_X284fCajabPCcbaSEIqR8nO4d40dvMv3NQ,1011
|
16
|
-
xax/task/task.py,sha256=X7TV_gt6C4m_-Il22Uyr5iMm-eh15oH5v1dl96sv1go,1295
|
17
|
-
xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
18
|
-
xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
|
19
|
-
xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
|
20
|
-
xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
|
21
|
-
xax/task/launchers/staged.py,sha256=jYeT9u58CN4ldV-ltJiQXQglEWOnEckHWnHYjfJQaoY,1102
|
22
|
-
xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
|
-
xax/task/loggers/callback.py,sha256=reaRuJs5iB6WWNgh3_tsuz_QPAlBC-5Ed2wCG_6Wj4M,2075
|
24
|
-
xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
|
25
|
-
xax/task/loggers/state.py,sha256=qyb-q8MdagN7BX-DhKucwoc45tIZJrPuvVDVoysTKC4,1576
|
26
|
-
xax/task/loggers/stdout.py,sha256=nxQXkS9JUR38RKsU9qj0dgePKguK0BFa9nl_BdGO8cE,6758
|
27
|
-
xax/task/loggers/tensorboard.py,sha256=DMYRDCQ9c-xHqO4kkZvc1-53PXCf2gX0aRiiAQDtHJ0,7293
|
28
|
-
xax/task/mixins/__init__.py,sha256=NkSAjMN5jpXE6LROIwMzX60z7UsTBpGs624_mNUWquo,745
|
29
|
-
xax/task/mixins/artifacts.py,sha256=G0984WuXII_R13IlJZn9En7iM83ISXKjeVYvn7j4wBs,3754
|
30
|
-
xax/task/mixins/checkpointing.py,sha256=JV91b5xyBUyZIbR3S-5UkBZNoAZYCnWx7Y-ayuU0lHQ,7989
|
31
|
-
xax/task/mixins/cpu_stats.py,sha256=Lqskt1t4usE6UslhANjwB0ZKOYmaC4dm9dnVKa6ERdA,8924
|
32
|
-
xax/task/mixins/data_loader.py,sha256=BPs0sYdctesnhS9nQ1rvT77MzLXznw5E4tAzWT1PpJY,5998
|
33
|
-
xax/task/mixins/gpu_stats.py,sha256=tFTNmtl9iMiLiYJSPg7gHR-ZxOP4P_ynzSmYNIAUoRw,8431
|
34
|
-
xax/task/mixins/logger.py,sha256=6XkjP_YUGY2CiDry0kDm1f9jqzJaLa1bPVYYnGjvSBU,2049
|
35
|
-
xax/task/mixins/process.py,sha256=HQAvEruvvfcS_IThrM4hKhFHZCAN2kFY_vEaZGLeZS8,1428
|
36
|
-
xax/task/mixins/runnable.py,sha256=d5-qyIpmNPtbTzE7qFJGGCPSREEDhX1VApUJPNDWye0,1933
|
37
|
-
xax/task/mixins/step_wrapper.py,sha256=Do4eGgZVuqDX9ZGDxQdfn6pRbUnHjQBAkTF0vnNH31E,1472
|
38
|
-
xax/task/mixins/train.py,sha256=Xeb0N9j-Znz5QnMDCXDGPqUSKMNLJkd8oF8giN45l2U,20099
|
39
|
-
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
40
|
-
xax/utils/experiments.py,sha256=qT3H0fyVH8DN417x7T0Xmz4SKoogW81-EHcZfyktFI8,28300
|
41
|
-
xax/utils/jax.py,sha256=VzEVB766UyH3_cgN6UP0FkCsDuGlYg5KJj8YJS4yYUk,439
|
42
|
-
xax/utils/logging.py,sha256=ST1hp2C2xntVVJBUHwo3YxPK19fBLNvHU2WGO1xqcXA,6418
|
43
|
-
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
44
|
-
xax/utils/tensorboard.py,sha256=XqxUlryFVsb75jE36uLcuoUhSr3nWg_-dzji2h6U_rI,8245
|
45
|
-
xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
|
46
|
-
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
47
|
-
xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
|
48
|
-
xax-0.0.5.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
49
|
-
xax-0.0.5.dist-info/METADATA,sha256=VCiQmbjwZtiuORVyB0dloFTgLWtnK4o3FaolNWvf-A4,937
|
50
|
-
xax-0.0.5.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
51
|
-
xax-0.0.5.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
52
|
-
xax-0.0.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|