xax 0.0.3__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +122 -8
- xax/core/conf.py +9 -33
- xax/core/state.py +13 -23
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +8 -4
- xax/requirements-dev.txt +9 -1
- xax/requirements.txt +17 -10
- xax/task/base.py +2 -6
- xax/task/logger.py +419 -412
- xax/task/loggers/callback.py +44 -0
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +16 -33
- xax/task/mixins/__init__.py +3 -1
- xax/task/mixins/artifacts.py +19 -9
- xax/task/mixins/checkpointing.py +221 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +26 -15
- xax/task/mixins/data_loader.py +27 -19
- xax/task/mixins/gpu_stats.py +22 -8
- xax/task/mixins/logger.py +5 -251
- xax/task/mixins/process.py +8 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +236 -145
- xax/task/script.py +1 -1
- xax/task/task.py +13 -5
- xax/utils/data/collate.py +6 -6
- xax/utils/experiments.py +45 -1
- xax/utils/logging.py +29 -0
- xax/utils/tensorboard.py +89 -21
- xax-0.0.6.dist-info/METADATA +50 -0
- xax-0.0.6.dist-info/RECORD +52 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.3.dist-info/METADATA +0 -39
- xax-0.0.3.dist-info/RECORD +0 -49
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/LICENSE +0 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/top_level.txt +0 -0
xax/utils/experiments.py
CHANGED
@@ -8,6 +8,7 @@ import hashlib
|
|
8
8
|
import inspect
|
9
9
|
import itertools
|
10
10
|
import logging
|
11
|
+
import math
|
11
12
|
import os
|
12
13
|
import random
|
13
14
|
import re
|
@@ -30,7 +31,7 @@ import requests
|
|
30
31
|
from jaxtyping import Array
|
31
32
|
from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf
|
32
33
|
|
33
|
-
from xax.core.conf import get_data_dir, get_pretrained_models_dir
|
34
|
+
from xax.core.conf import get_data_dir, get_pretrained_models_dir, load_user_config
|
34
35
|
from xax.core.state import State
|
35
36
|
from xax.utils.text import colored
|
36
37
|
|
@@ -665,6 +666,7 @@ class BaseFileDownloader(ABC):
|
|
665
666
|
with requests.Session() as session:
|
666
667
|
response = session.get(url, params=params, stream=True)
|
667
668
|
|
669
|
+
token: str | None = None
|
668
670
|
for key, value in response.cookies.items():
|
669
671
|
if key.startswith("download_warning"):
|
670
672
|
token = value
|
@@ -756,3 +758,45 @@ def get_state_dict_prefix(
|
|
756
758
|
if regexp is not None:
|
757
759
|
ckpt = {k: v for k, v in ckpt.items() if regexp.match(k)}
|
758
760
|
return ckpt
|
761
|
+
|
762
|
+
|
763
|
+
def split_n_items_across_workers(n: int, worker_id: int, num_workers: int) -> tuple[int, int]:
    """Computes offsets for splitting N items across K workers.

    This returns the start and end indices for the items to be processed by the
    given worker. The end index is exclusive. Taken together, the ranges for
    workers ``0 .. num_workers - 1`` are contiguous, cover ``[0, n)`` exactly,
    and differ in size by at most one item: the first ``n % num_workers``
    workers each receive one extra item.

    Args:
        n: The number of items to process.
        worker_id: The ID of the current worker.
        num_workers: The total number of workers.

    Returns:
        The start and end index for the items in the current worker.
    """
    assert n >= num_workers, f"n ({n}) must be >= num_workers ({num_workers})"
    assert 0 <= worker_id < num_workers, f"worker_id ({worker_id}) must be >= 0 and < num_workers ({num_workers})"

    # Every worker gets `base` items and the `extra` leftover items go to the
    # first workers. Giving every worker ceil(n / num_workers) items instead
    # can overshoot `n` and yield start > end for trailing workers
    # (e.g. n=5, num_workers=4 would give worker 3 the range (6, 5)).
    base, extra = divmod(n, num_workers)
    start = worker_id * base + min(worker_id, extra)
    end = start + base + (1 if worker_id < extra else 0)

    return start, end
|
788
|
+
|
789
|
+
|
790
|
+
def num_workers(default: int) -> int:
    """Returns the number of workers to use, capped by the user config.

    The CPU count is taken from the process's CPU affinity set when
    ``os.sched_getaffinity`` is available and succeeds, otherwise from
    ``os.cpu_count()``, and finally from ``default`` when neither yields a
    value. The result is always clipped to
    ``load_user_config().experiment.max_workers``.

    Args:
        default: The worker count to fall back on when the CPU count cannot
            be determined.

    Returns:
        The number of workers to use.
    """
    max_workers = load_user_config().experiment.max_workers

    if hasattr(os, "sched_getaffinity"):
        # Only the affinity query itself is guarded; if it fails, fall back
        # to the total CPU count below instead of swallowing unrelated errors.
        try:
            affinity_count = len(os.sched_getaffinity(0))
        except OSError:
            pass
        else:
            return min(affinity_count, max_workers)

    if (cpu_count := os.cpu_count()) is not None:
        return min(cpu_count, max_workers)
    return min(default, max_workers)
|
800
|
+
|
801
|
+
|
802
|
+
# Registers `num_workers` as an OmegaConf resolver so configs can interpolate it
# by name; `replace=True` overwrites any resolver already registered under it.
# NOTE(review): the "mlfab." prefix looks carried over from another package;
# confirm whether an "xax." name is intended before renaming (configs may rely on it).
OmegaConf.register_new_resolver("mlfab.num_workers", num_workers, replace=True)
|
xax/utils/logging.py
CHANGED
@@ -5,6 +5,8 @@ import math
|
|
5
5
|
import socket
|
6
6
|
import sys
|
7
7
|
|
8
|
+
from omegaconf import OmegaConf
|
9
|
+
|
8
10
|
from xax.core.conf import load_user_config
|
9
11
|
from xax.utils.text import Color, color_parts, colored
|
10
12
|
|
@@ -175,6 +177,33 @@ def configure_logging(prefix: str | None = None, *, rank: int | None = None, wor
|
|
175
177
|
logging.getLogger("torch").setLevel(logging.WARNING)
|
176
178
|
|
177
179
|
|
180
|
+
def get_unused_port(default: int | None = None) -> int:
|
181
|
+
"""Returns an unused port number on the local machine.
|
182
|
+
|
183
|
+
Args:
|
184
|
+
default: A default port to try before trying other ports.
|
185
|
+
|
186
|
+
Returns:
|
187
|
+
A port number which is currently unused
|
188
|
+
"""
|
189
|
+
if default is not None:
|
190
|
+
sock = socket.socket()
|
191
|
+
try:
|
192
|
+
sock.bind(("", default))
|
193
|
+
return default
|
194
|
+
except OSError:
|
195
|
+
pass
|
196
|
+
finally:
|
197
|
+
sock.close()
|
198
|
+
|
199
|
+
sock = socket.socket()
|
200
|
+
sock.bind(("", 0))
|
201
|
+
return sock.getsockname()[1]
|
202
|
+
|
203
|
+
|
204
|
+
# Registers `get_unused_port` as an OmegaConf resolver so configs can interpolate
# a free port by name; `replace=True` overwrites any resolver already registered
# under this name.
# NOTE(review): the "mlfab." prefix looks carried over from another package;
# confirm whether an "xax." name is intended before renaming (configs may rely on it).
OmegaConf.register_new_resolver("mlfab.unused_port", get_unused_port, replace=True)
|
205
|
+
|
206
|
+
|
178
207
|
def port_is_busy(port: int) -> int:
|
179
208
|
"""Checks whether a port is busy.
|
180
209
|
|
xax/utils/tensorboard.py
CHANGED
@@ -1,11 +1,15 @@
|
|
1
1
|
"""Defines utility functions for interfacing with Tensorboard."""
|
2
2
|
|
3
3
|
import functools
|
4
|
+
import io
|
5
|
+
import os
|
6
|
+
import tempfile
|
4
7
|
import time
|
5
8
|
from pathlib import Path
|
6
9
|
from typing import Literal, TypedDict
|
7
10
|
|
8
11
|
import numpy as np
|
12
|
+
import PIL.Image
|
9
13
|
from PIL.Image import Image as PILImage
|
10
14
|
from tensorboard.compat.proto.config_pb2 import RunMetadata
|
11
15
|
from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
|
@@ -117,47 +121,111 @@ class TensorboardWriter:
|
|
117
121
|
value: float,
|
118
122
|
global_step: int | None = None,
|
119
123
|
walltime: float | None = None,
|
124
|
+
new_style: bool = True,
|
120
125
|
double_precision: bool = False,
|
121
126
|
) -> None:
|
127
|
+
if new_style:
|
128
|
+
self.pb_writer.add_summary(
|
129
|
+
Summary(
|
130
|
+
value=[
|
131
|
+
Summary.Value(
|
132
|
+
tag=tag,
|
133
|
+
tensor=(
|
134
|
+
TensorProto(double_val=[value], dtype="DT_DOUBLE")
|
135
|
+
if double_precision
|
136
|
+
else TensorProto(float_val=[value], dtype="DT_FLOAT")
|
137
|
+
),
|
138
|
+
metadata=SummaryMetadata(
|
139
|
+
plugin_data=SummaryMetadata.PluginData(
|
140
|
+
plugin_name="scalars",
|
141
|
+
),
|
142
|
+
),
|
143
|
+
)
|
144
|
+
],
|
145
|
+
),
|
146
|
+
global_step=global_step,
|
147
|
+
walltime=walltime,
|
148
|
+
)
|
149
|
+
else:
|
150
|
+
self.pb_writer.add_summary(
|
151
|
+
Summary(
|
152
|
+
value=[
|
153
|
+
Summary.Value(
|
154
|
+
tag=tag,
|
155
|
+
simple_value=value,
|
156
|
+
),
|
157
|
+
],
|
158
|
+
),
|
159
|
+
global_step=global_step,
|
160
|
+
walltime=walltime,
|
161
|
+
)
|
162
|
+
|
163
|
+
def add_image(
|
164
|
+
self,
|
165
|
+
tag: str,
|
166
|
+
value: PILImage,
|
167
|
+
global_step: int | None = None,
|
168
|
+
walltime: float | None = None,
|
169
|
+
) -> None:
|
170
|
+
output = io.BytesIO()
|
171
|
+
value.convert("RGB").save(output, format="PNG")
|
172
|
+
image_string = output.getvalue()
|
173
|
+
output.close()
|
174
|
+
|
122
175
|
self.pb_writer.add_summary(
|
123
176
|
Summary(
|
124
177
|
value=[
|
125
178
|
Summary.Value(
|
126
179
|
tag=tag,
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
180
|
+
image=Summary.Image(
|
181
|
+
height=value.height,
|
182
|
+
width=value.width,
|
183
|
+
colorspace=3, # RGB
|
184
|
+
encoded_image_string=image_string,
|
131
185
|
),
|
132
|
-
|
133
|
-
plugin_data=SummaryMetadata.PluginData(
|
134
|
-
plugin_name="scalaras",
|
135
|
-
),
|
136
|
-
),
|
137
|
-
)
|
186
|
+
),
|
138
187
|
],
|
139
188
|
),
|
140
189
|
global_step=global_step,
|
141
190
|
walltime=walltime,
|
142
191
|
)
|
143
192
|
|
144
|
-
def
|
193
|
+
def add_video(
|
145
194
|
self,
|
146
195
|
tag: str,
|
147
|
-
value:
|
196
|
+
value: np.ndarray,
|
148
197
|
global_step: int | None = None,
|
149
198
|
walltime: float | None = None,
|
199
|
+
fps: int = 30,
|
150
200
|
) -> None:
|
151
|
-
|
201
|
+
assert value.ndim == 4, "Video must be 4D array (T, H, W, C)"
|
202
|
+
images = [PIL.Image.fromarray(frame) for frame in value]
|
203
|
+
|
204
|
+
# Create temporary file for GIF
|
205
|
+
temp_file = tempfile.NamedTemporaryFile(suffix=".gif", delete=False)
|
206
|
+
try:
|
207
|
+
images[0].save(temp_file.name, save_all=True, append_images=images[1:], duration=int(1000 / fps), loop=0)
|
208
|
+
with open(temp_file.name, "rb") as f:
|
209
|
+
video_string = f.read()
|
210
|
+
|
211
|
+
finally:
|
212
|
+
# Clean up temporary file
|
213
|
+
try:
|
214
|
+
os.remove(temp_file.name)
|
215
|
+
except OSError:
|
216
|
+
pass
|
217
|
+
|
218
|
+
# Add to summary
|
152
219
|
self.pb_writer.add_summary(
|
153
220
|
Summary(
|
154
221
|
value=[
|
155
222
|
Summary.Value(
|
156
223
|
tag=tag,
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
224
|
+
image=Summary.Image(
|
225
|
+
height=value.shape[1],
|
226
|
+
width=value.shape[2],
|
227
|
+
colorspace=value.shape[3],
|
228
|
+
encoded_image_string=video_string,
|
161
229
|
),
|
162
230
|
),
|
163
231
|
],
|
@@ -197,7 +265,6 @@ class TensorboardWriter:
|
|
197
265
|
|
198
266
|
|
199
267
|
class TensorboardWriterKwargs(TypedDict):
|
200
|
-
log_directory: Path
|
201
268
|
max_queue_size: int
|
202
269
|
flush_seconds: float
|
203
270
|
filename_suffix: str
|
@@ -213,8 +280,9 @@ class TensorboardWriters:
|
|
213
280
|
) -> None:
|
214
281
|
super().__init__()
|
215
282
|
|
283
|
+
self.log_directory = Path(log_directory)
|
284
|
+
|
216
285
|
self.kwargs: TensorboardWriterKwargs = {
|
217
|
-
"log_directory": Path(log_directory),
|
218
286
|
"max_queue_size": max_queue_size,
|
219
287
|
"flush_seconds": flush_seconds,
|
220
288
|
"filename_suffix": filename_suffix,
|
@@ -222,11 +290,11 @@ class TensorboardWriters:
|
|
222
290
|
|
223
291
|
@functools.cached_property
|
224
292
|
def train_writer(self) -> TensorboardWriter:
|
225
|
-
return TensorboardWriter(**self.kwargs)
|
293
|
+
return TensorboardWriter(self.log_directory / "train", **self.kwargs)
|
226
294
|
|
227
295
|
@functools.cached_property
|
228
296
|
def valid_writer(self) -> TensorboardWriter:
|
229
|
-
return TensorboardWriter(**self.kwargs)
|
297
|
+
return TensorboardWriter(self.log_directory / "valid", **self.kwargs)
|
230
298
|
|
231
299
|
def writer(self, phase: Phase) -> TensorboardWriter:
|
232
300
|
match phase:
|
@@ -0,0 +1,50 @@
|
|
1
|
+
Metadata-Version: 2.2
|
2
|
+
Name: xax
|
3
|
+
Version: 0.0.6
|
4
|
+
Summary: The xax project
|
5
|
+
Home-page: https://github.com/dpshai/xax
|
6
|
+
Author: Benjamin Bolte
|
7
|
+
Requires-Python: >=3.11
|
8
|
+
Description-Content-Type: text/markdown
|
9
|
+
License-File: LICENSE
|
10
|
+
Requires-Dist: jax
|
11
|
+
Requires-Dist: jaxtyping
|
12
|
+
Requires-Dist: equinox
|
13
|
+
Requires-Dist: optax
|
14
|
+
Requires-Dist: dpshdl
|
15
|
+
Requires-Dist: chex
|
16
|
+
Requires-Dist: importlib-resources
|
17
|
+
Requires-Dist: cloudpickle
|
18
|
+
Requires-Dist: pillow
|
19
|
+
Requires-Dist: omegaconf
|
20
|
+
Requires-Dist: gitpython
|
21
|
+
Requires-Dist: tensorboard
|
22
|
+
Requires-Dist: psutil
|
23
|
+
Requires-Dist: requests
|
24
|
+
Provides-Extra: dev
|
25
|
+
Requires-Dist: black; extra == "dev"
|
26
|
+
Requires-Dist: darglint; extra == "dev"
|
27
|
+
Requires-Dist: mypy; extra == "dev"
|
28
|
+
Requires-Dist: ruff; extra == "dev"
|
29
|
+
Requires-Dist: pytest; extra == "dev"
|
30
|
+
Requires-Dist: types-pillow; extra == "dev"
|
31
|
+
Requires-Dist: types-psutil; extra == "dev"
|
32
|
+
Requires-Dist: types-requests; extra == "dev"
|
33
|
+
Dynamic: author
|
34
|
+
Dynamic: description
|
35
|
+
Dynamic: description-content-type
|
36
|
+
Dynamic: home-page
|
37
|
+
Dynamic: provides-extra
|
38
|
+
Dynamic: requires-dist
|
39
|
+
Dynamic: requires-python
|
40
|
+
Dynamic: summary
|
41
|
+
|
42
|
+
# xax
|
43
|
+
|
44
|
+
JAX library for fast experimentation.
|
45
|
+
|
46
|
+
## Installation
|
47
|
+
|
48
|
+
```bash
|
49
|
+
pip install xax
|
50
|
+
```
|
@@ -0,0 +1,52 @@
|
|
1
|
+
xax/__init__.py,sha256=RTUsDh_R0TFa09q-_U0vd-eCYRC-bCaHqHlayp8U2hU,9736
|
2
|
+
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
|
+
xax/requirements.txt,sha256=NmU9PNJhfLtNqqtWWf8WqMjgbBPCn_yt8oMGAgS7Fno,291
|
5
|
+
xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
+
xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
|
7
|
+
xax/core/state.py,sha256=y123fL7pMgk25TPG6KN0LRIF_eYnD9eP7OfqtoQJGNE,2178
|
8
|
+
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
+
xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
10
|
+
xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
|
11
|
+
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
12
|
+
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
+
xax/task/base.py,sha256=LHDmM2c_Ps5cGEzn_QUpmyInD7zJJm3Yt9eSeij2Vus,7297
|
14
|
+
xax/task/logger.py,sha256=orN1jmM4SIR2EiYk8bNoJZscmhX1FytADBU6p9qpows,29256
|
15
|
+
xax/task/script.py,sha256=4LyXrpj0V36TjAZT4lvQeiOTqa5U2tommHKwgWDCE24,1025
|
16
|
+
xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
|
17
|
+
xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
18
|
+
xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
|
19
|
+
xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
|
20
|
+
xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
|
21
|
+
xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
22
|
+
xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
|
23
|
+
xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
|
24
|
+
xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
|
25
|
+
xax/task/loggers/stdout.py,sha256=nxQXkS9JUR38RKsU9qj0dgePKguK0BFa9nl_BdGO8cE,6758
|
26
|
+
xax/task/loggers/tensorboard.py,sha256=FGW96z77oG0Kf3cO6Zznx5U3kJNzPWcuSkpY4RnbFCo,6909
|
27
|
+
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
28
|
+
xax/task/mixins/artifacts.py,sha256=1H7ZbR-KSsXhVtqGVlqMi-TXfn1-dM7YnTCLVuw594s,3835
|
29
|
+
xax/task/mixins/checkpointing.py,sha256=AMlobojybvJdDZcNCxm1DHSVC_2Qvnu_MbRcsc_8eoA,8508
|
30
|
+
xax/task/mixins/compile.py,sha256=FRsxwLnZjjxpeWJ7Bx_d8XUY50oDoGidgpeRt4ejeQk,3377
|
31
|
+
xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
|
32
|
+
xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
|
33
|
+
xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
|
34
|
+
xax/task/mixins/logger.py,sha256=CIQ4w4K3FcxN6A9xUfITdVkulSxPa4iaTe6cbs9ruaM,1958
|
35
|
+
xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
|
36
|
+
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
37
|
+
xax/task/mixins/step_wrapper.py,sha256=DJw42mUGwgKx2tkeqatKR9_F4J8ug4wmxKMeJPmhcVQ,1560
|
38
|
+
xax/task/mixins/train.py,sha256=dhGL_IuDaJy39BooYlO7JO-_EotKldtBhBplDGU_AnM,21745
|
39
|
+
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
40
|
+
xax/utils/experiments.py,sha256=qT3H0fyVH8DN417x7T0Xmz4SKoogW81-EHcZfyktFI8,28300
|
41
|
+
xax/utils/jax.py,sha256=VzEVB766UyH3_cgN6UP0FkCsDuGlYg5KJj8YJS4yYUk,439
|
42
|
+
xax/utils/logging.py,sha256=ST1hp2C2xntVVJBUHwo3YxPK19fBLNvHU2WGO1xqcXA,6418
|
43
|
+
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
44
|
+
xax/utils/tensorboard.py,sha256=oGq2E3Yr0z2xaACv2UOVt_CHEVc8fBxI8V1M99Fd34E,9742
|
45
|
+
xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
|
46
|
+
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
47
|
+
xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
|
48
|
+
xax-0.0.6.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
49
|
+
xax-0.0.6.dist-info/METADATA,sha256=YO2c2PUMWkH1ILfPhFWKK4Sodbo9qUpUOCIkm4aLHfg,1171
|
50
|
+
xax-0.0.6.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
|
51
|
+
xax-0.0.6.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
52
|
+
xax-0.0.6.dist-info/RECORD,,
|
xax/task/launchers/staged.py
DELETED
@@ -1,29 +0,0 @@
|
|
1
|
-
"""Defines a base class with utility functions for staged training runs."""
|
2
|
-
|
3
|
-
from abc import ABC
|
4
|
-
from pathlib import Path
|
5
|
-
|
6
|
-
from xax.task.launchers.base import BaseLauncher
|
7
|
-
from xax.task.mixins.artifacts import ArtifactsMixin, Config
|
8
|
-
|
9
|
-
|
10
|
-
class StagedLauncher(BaseLauncher, ABC):
|
11
|
-
def __init__(self, config_file_name: str = "config.yaml") -> None:
|
12
|
-
super().__init__()
|
13
|
-
|
14
|
-
self.config_file_name = config_file_name
|
15
|
-
|
16
|
-
def get_config_path(self, task: "ArtifactsMixin[Config]", use_cli: bool | list[str] = True) -> Path:
|
17
|
-
config_path = task.exp_dir / self.config_file_name
|
18
|
-
task.config.exp_dir = str(task.exp_dir)
|
19
|
-
with open(config_path, "w", encoding="utf-8") as f:
|
20
|
-
f.write(task.config_str(task.config, use_cli=use_cli))
|
21
|
-
return config_path
|
22
|
-
|
23
|
-
@classmethod
|
24
|
-
def from_components(cls, task_key: str, config_path: Path, use_cli: bool | list[str] = True) -> "ArtifactsMixin":
|
25
|
-
return (
|
26
|
-
ArtifactsMixin.from_task_key(task_key)
|
27
|
-
.get_task(config_path, use_cli=use_cli)
|
28
|
-
.set_exp_dir(config_path.parent)
|
29
|
-
)
|
xax-0.0.3.dist-info/METADATA
DELETED
@@ -1,39 +0,0 @@
|
|
1
|
-
Metadata-Version: 2.1
|
2
|
-
Name: xax
|
3
|
-
Version: 0.0.3
|
4
|
-
Summary: The xax project
|
5
|
-
Home-page: https://github.com/dpshai/xax
|
6
|
-
Author: Benjamin Bolte
|
7
|
-
Requires-Python: >=3.11
|
8
|
-
Description-Content-Type: text/markdown
|
9
|
-
License-File: LICENSE
|
10
|
-
Requires-Dist: dpshdl
|
11
|
-
Requires-Dist: equinox
|
12
|
-
Requires-Dist: gitpython
|
13
|
-
Requires-Dist: jax
|
14
|
-
Requires-Dist: jaxtyping
|
15
|
-
Requires-Dist: omegaconf
|
16
|
-
Requires-Dist: optax
|
17
|
-
Requires-Dist: pillow
|
18
|
-
Requires-Dist: psutil
|
19
|
-
Requires-Dist: requests
|
20
|
-
Requires-Dist: tensorboard
|
21
|
-
Requires-Dist: types-pillow
|
22
|
-
Requires-Dist: types-psutil
|
23
|
-
Requires-Dist: types-requests
|
24
|
-
Provides-Extra: dev
|
25
|
-
Requires-Dist: black ; extra == 'dev'
|
26
|
-
Requires-Dist: darglint ; extra == 'dev'
|
27
|
-
Requires-Dist: mypy ; extra == 'dev'
|
28
|
-
Requires-Dist: pytest ; extra == 'dev'
|
29
|
-
Requires-Dist: ruff ; extra == 'dev'
|
30
|
-
|
31
|
-
# xax
|
32
|
-
|
33
|
-
JAX library for fast experimentation.
|
34
|
-
|
35
|
-
## Installation
|
36
|
-
|
37
|
-
```bash
|
38
|
-
pip install xax
|
39
|
-
```
|
xax-0.0.3.dist-info/RECORD
DELETED
@@ -1,49 +0,0 @@
|
|
1
|
-
xax/__init__.py,sha256=awE9-hwGPLSx2ywUAUoeBYYm3ztERRhejiZTM8C2ft4,6160
|
2
|
-
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
xax/requirements-dev.txt,sha256=aa9ql8CtmuLTvswvwXYlfAxRL0MEyondkMVcuZuM1A8,56
|
4
|
-
xax/requirements.txt,sha256=eJ8pC2D3CrbjcETujBG-j3y9dt6WMGpT-wQQzAUh2-4,161
|
5
|
-
xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
-
xax/core/conf.py,sha256=DrLXM0VED5RexuqF8-LOljGrnU5soYFgugIrvx6tGmE,5841
|
7
|
-
xax/core/state.py,sha256=7lnVSytuhwPfcobPGdjfQ0QxbLgzWQNipKwXchd58QI,2695
|
8
|
-
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
-
xax/nn/functions.py,sha256=ejIltzHagoiKjMEkfVpcIFj_OS3GkMd5JElWUDCrT6Y,2471
|
10
|
-
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
11
|
-
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
xax/task/base.py,sha256=4oueHzToECJYBY-umnrVqDY9tJ6bT18cP97TtaQukBE,7411
|
13
|
-
xax/task/logger.py,sha256=lmZ-cfiBisNskz1JDrQmtydT0uwoAAmr38V8-W8XqiU,29200
|
14
|
-
xax/task/script.py,sha256=oBGnScYa_X284fCajabPCcbaSEIqR8nO4d40dvMv3NQ,1011
|
15
|
-
xax/task/task.py,sha256=mUY7lEug9dW5yrTyT1V5gro5OosnmilMkkWLC-FbI4E,1266
|
16
|
-
xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
|
-
xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
|
18
|
-
xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
|
19
|
-
xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
|
20
|
-
xax/task/launchers/staged.py,sha256=jYeT9u58CN4ldV-ltJiQXQglEWOnEckHWnHYjfJQaoY,1102
|
21
|
-
xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
22
|
-
xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
|
23
|
-
xax/task/loggers/state.py,sha256=qyb-q8MdagN7BX-DhKucwoc45tIZJrPuvVDVoysTKC4,1576
|
24
|
-
xax/task/loggers/stdout.py,sha256=nxQXkS9JUR38RKsU9qj0dgePKguK0BFa9nl_BdGO8cE,6758
|
25
|
-
xax/task/loggers/tensorboard.py,sha256=Vt9Vr6i5Bmon7z0vykiQKcx_X9qsYTrRgKf1isLQrEA,7388
|
26
|
-
xax/task/mixins/__init__.py,sha256=NB8CVbx-zxMpWDYeQJJCUIrUOfrNeVc8mE59qXayEsc,685
|
27
|
-
xax/task/mixins/artifacts.py,sha256=W8Y40aqa7QL8mfM5wMRBszLG8Xpv0VZMT3yRnFLfCOU,3568
|
28
|
-
xax/task/mixins/cpu_stats.py,sha256=BeG26vkfy4ePAXZfKfoONyHwpCplK-YrbeeEkVNmF8E,8854
|
29
|
-
xax/task/mixins/data_loader.py,sha256=CWoozUdNd5UgJ2wzVktHTDBDplb0TZJAThq53sYrKsM,6160
|
30
|
-
xax/task/mixins/gpu_stats.py,sha256=HHwBGr56u5MfV_JZDgDvyxnxNgVSr-yNLLfy2tGzyW8,8410
|
31
|
-
xax/task/mixins/logger.py,sha256=lWO3nsEdA92Cd8z_JAiD7VBLgiTr2FS5JK2ZJSajexc,9281
|
32
|
-
xax/task/mixins/process.py,sha256=am6HeAI9Mw7FmrpYX8tMfIzgq-WSYZwm5T6qKuUlOsw,1327
|
33
|
-
xax/task/mixins/runnable.py,sha256=d5-qyIpmNPtbTzE7qFJGGCPSREEDhX1VApUJPNDWye0,1933
|
34
|
-
xax/task/mixins/step_wrapper.py,sha256=Do4eGgZVuqDX9ZGDxQdfn6pRbUnHjQBAkTF0vnNH31E,1472
|
35
|
-
xax/task/mixins/train.py,sha256=blNJEjUL6FnvHUIhIa8a5au2TUw3OEv57hygfIGfNmI,18626
|
36
|
-
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
37
|
-
xax/utils/experiments.py,sha256=EgTTugyBznGhyft04RIA5tDnDGIrib3y3FKa2J6i4sU,26766
|
38
|
-
xax/utils/jax.py,sha256=VzEVB766UyH3_cgN6UP0FkCsDuGlYg5KJj8YJS4yYUk,439
|
39
|
-
xax/utils/logging.py,sha256=CIOvoH7iq70IRvysG-14OuHSgdvoIfzttwMWiw2yu3E,5732
|
40
|
-
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
41
|
-
xax/utils/tensorboard.py,sha256=BEa1OmOl_7eiq9yahMI2vd6sc_VEBiZnzkEPBGv3IVs,7654
|
42
|
-
xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
|
43
|
-
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
44
|
-
xax/utils/data/collate.py,sha256=fuXTzOxdi-STgtRUU2DyJTF72Yz42kWLr6nlszZs6tM,7003
|
45
|
-
xax-0.0.3.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
46
|
-
xax-0.0.3.dist-info/METADATA,sha256=qLBYRcvPbllBPh-VVLDcpSjaYWQJyQxhwrjlx98DTV0,867
|
47
|
-
xax-0.0.3.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
48
|
-
xax-0.0.3.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
49
|
-
xax-0.0.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|