xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +256 -1
- xax/core/conf.py +193 -0
- xax/core/state.py +81 -0
- xax/nn/__init__.py +0 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +77 -0
- xax/nn/parallel.py +211 -0
- xax/requirements-dev.txt +15 -0
- xax/requirements.txt +23 -0
- xax/task/__init__.py +0 -0
- xax/task/base.py +207 -0
- xax/task/launchers/__init__.py +0 -0
- xax/task/launchers/base.py +28 -0
- xax/task/launchers/cli.py +42 -0
- xax/task/launchers/single_process.py +30 -0
- xax/task/launchers/staged.py +29 -0
- xax/task/logger.py +783 -0
- xax/task/loggers/__init__.py +0 -0
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/json.py +121 -0
- xax/task/loggers/state.py +45 -0
- xax/task/loggers/stdout.py +170 -0
- xax/task/loggers/tensorboard.py +223 -0
- xax/task/mixins/__init__.py +12 -0
- xax/task/mixins/artifacts.py +114 -0
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +251 -0
- xax/task/mixins/data_loader.py +149 -0
- xax/task/mixins/gpu_stats.py +257 -0
- xax/task/mixins/logger.py +66 -0
- xax/task/mixins/process.py +51 -0
- xax/task/mixins/runnable.py +63 -0
- xax/task/mixins/step_wrapper.py +63 -0
- xax/task/mixins/train.py +541 -0
- xax/task/script.py +53 -0
- xax/task/task.py +65 -0
- xax/utils/__init__.py +0 -0
- xax/utils/data/__init__.py +0 -0
- xax/utils/data/collate.py +206 -0
- xax/utils/experiments.py +802 -0
- xax/utils/jax.py +14 -0
- xax/utils/logging.py +223 -0
- xax/utils/numpy.py +47 -0
- xax/utils/tensorboard.py +258 -0
- xax/utils/text.py +350 -0
- xax-0.0.5.dist-info/METADATA +40 -0
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.5.dist-info/top_level.txt +1 -0
- examples/mnist.py +0 -148
- xax-0.0.1.dist-info/METADATA +0 -21
- xax-0.0.1.dist-info/RECORD +0 -9
- xax-0.0.1.dist-info/top_level.txt +0 -2
- {examples → xax/core}/__init__.py +0 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/utils/text.py
ADDED
@@ -0,0 +1,350 @@
|
|
1
|
+
"""Defines helper functions for displaying text in the terminal."""
|
2
|
+
|
3
|
+
import datetime
|
4
|
+
import itertools
|
5
|
+
import re
|
6
|
+
import sys
|
7
|
+
from typing import Literal
|
8
|
+
|
9
|
+
# ANSI SGR escape sequences. The ``%d`` templates are filled with a code from
# ``COLOR_INDEX``; ``RESET_SEQ`` turns all attributes off again.
RESET_SEQ = "\033[0m"
REG_COLOR_SEQ = "\033[%dm"
BOLD_COLOR_SEQ = "\033[1;%dm"
BOLD_SEQ = "\033[1m"

# Color names accepted by the helpers in this module.
Color = Literal[
    # Standard palette.
    "black", "red", "green", "yellow", "blue", "magenta", "cyan", "white",
    # Bright palette.
    "grey", "light-red", "light-green", "light-yellow", "light-blue", "light-magenta", "light-cyan",
]

# Maps each color name to the numeric code substituted into the templates above.
COLOR_INDEX: dict[Color, int] = {
    # Standard palette: 30-37.
    "black": 30, "red": 31, "green": 32, "yellow": 33,
    "blue": 34, "magenta": 35, "cyan": 36, "white": 37,
    # Bright palette: 90-96.
    "grey": 90, "light-red": 91, "light-green": 92, "light-yellow": 93,
    "light-blue": 94, "light-magenta": 95, "light-cyan": 96,
}
|
49
|
+
|
50
|
+
|
51
|
+
def color_parts(color: Color, bold: bool = False) -> tuple[str, str]:
    """Returns the escape sequences that open and close a colored span.

    Args:
        color: Name of the color to switch to.
        bold: Whether the opening sequence should also enable bold.

    Returns:
        A ``(start, end)`` pair; wrap text as ``start + text + end``.
    """
    template = BOLD_COLOR_SEQ if bold else REG_COLOR_SEQ
    return template % COLOR_INDEX[color], RESET_SEQ
|
55
|
+
|
56
|
+
|
57
|
+
def uncolored(s: str) -> str:
    """Strips ANSI SGR (color / style) escape sequences from a string.

    Args:
        s: The possibly-colored string.

    Returns:
        ``s`` with every ``ESC [ <params> m`` sequence removed. The parameter
        list may be empty, so the bare reset ``ESC [ m`` is removed too.
    """
    # ``*`` instead of ``+``: ``\033[m`` (no parameters) is a valid reset.
    return re.sub(r"\033\[[\d;]*m", "", s)
|
59
|
+
|
60
|
+
|
61
|
+
def colored(s: str, color: Color | None = None, bold: bool = False) -> str:
    """Wraps ``s`` in the ANSI escape codes for ``color``.

    Args:
        s: The text to color.
        color: The color to apply. ``None`` returns ``s`` unchanged (and
            ``bold`` is ignored).
        bold: Whether to render the text in bold as well.

    Returns:
        The colored string, or ``s`` itself when no color is given.
    """
    if color is not None:
        prefix, suffix = color_parts(color, bold=bold)
        s = prefix + s + suffix
    return s
|
66
|
+
|
67
|
+
|
68
|
+
def wrapped(
|
69
|
+
s: str,
|
70
|
+
length: int | None = None,
|
71
|
+
space: str = " ",
|
72
|
+
spaces: str | re.Pattern = r" ",
|
73
|
+
newlines: str | re.Pattern = r"[\n\r]",
|
74
|
+
too_long_suffix: str = "...",
|
75
|
+
) -> list[str]:
|
76
|
+
strings = []
|
77
|
+
lines = re.split(newlines, s.strip(), flags=re.MULTILINE | re.UNICODE)
|
78
|
+
for line in lines:
|
79
|
+
cur_string = []
|
80
|
+
cur_length = 0
|
81
|
+
for part in re.split(spaces, line.strip(), flags=re.MULTILINE | re.UNICODE):
|
82
|
+
if length is None:
|
83
|
+
cur_string.append(part)
|
84
|
+
cur_length += len(space) + len(part)
|
85
|
+
else:
|
86
|
+
if len(part) > length:
|
87
|
+
part = part[: length - len(too_long_suffix)] + too_long_suffix
|
88
|
+
if cur_length + len(part) > length:
|
89
|
+
strings.append(space.join(cur_string))
|
90
|
+
cur_string = [part]
|
91
|
+
cur_length = len(part)
|
92
|
+
else:
|
93
|
+
cur_string.append(part)
|
94
|
+
cur_length += len(space) + len(part)
|
95
|
+
if cur_length > 0:
|
96
|
+
strings.append(space.join(cur_string))
|
97
|
+
return strings
|
98
|
+
|
99
|
+
|
100
|
+
def outlined(
    s: str,
    inner: Color | None = None,
    side: Color | None = None,
    bold: bool = False,
    max_length: int | None = None,
    space: str = " ",
    spaces: str | re.Pattern = r" ",
    newlines: str | re.Pattern = r"[\n\r]",
) -> str:
    """Draws a box-drawing border around a block of text.

    Args:
        s: The text to outline. Existing ANSI color codes are removed first.
        inner: Color for the text inside the box.
        side: Color for the border characters.
        bold: Whether the inner text is rendered in bold.
        max_length: If set, the text is wrapped to this many characters per line.
        space: Separator used when re-joining words (see ``wrapped``).
        spaces: Pattern used to split lines into words (see ``wrapped``).
        newlines: Pattern used to split the text into lines (see ``wrapped``).

    Returns:
        The boxed text as a single newline-separated string.
    """
    rows = wrapped(uncolored(s), max_length, space, spaces, newlines)
    inner_width = max(len(row) for row in rows)
    bar = colored("│", side)
    # Every row is left-aligned and padded to the widest row before coloring.
    body = [f"{bar} {colored(row.ljust(inner_width), inner, bold=bold)} {bar}" for row in rows]
    top = colored("┌─" + "─" * inner_width + "─┐", side)
    bottom = colored("└─" + "─" * inner_width + "─┘", side)
    return "\n".join([top, *body, bottom])
|
118
|
+
|
119
|
+
|
120
|
+
def show_info(s: str, important: bool = False) -> None:
    """Writes an informational message to stdout in cyan.

    Args:
        s: The message to show.
        important: If set, the message is drawn bold inside an outlined box
            instead of just being colored.
    """
    if not important:
        message = colored(s, "light-cyan", bold=False)
    else:
        message = outlined(s, inner="light-cyan", side="cyan", bold=True)
    sys.stdout.write(f"{message}\n")
    sys.stdout.flush()
|
128
|
+
|
129
|
+
|
130
|
+
def show_error(s: str, important: bool = False) -> None:
    """Writes an error message to stdout (not stderr) in red.

    Args:
        s: The message to show.
        important: If set, the message is drawn bold inside an outlined box
            instead of just being colored.
    """
    if not important:
        message = colored(s, "light-red", bold=False)
    else:
        message = outlined(s, inner="light-red", side="red", bold=True)
    sys.stdout.write(f"{message}\n")
    sys.stdout.flush()
|
138
|
+
|
139
|
+
|
140
|
+
def show_warning(s: str, important: bool = False) -> None:
    """Writes a warning message to stdout in yellow.

    Args:
        s: The message to show.
        important: If set, the message is drawn bold inside an outlined box
            instead of just being colored.
    """
    if not important:
        message = colored(s, "light-yellow", bold=False)
    else:
        message = outlined(s, inner="light-yellow", side="yellow", bold=True)
    sys.stdout.write(f"{message}\n")
    sys.stdout.flush()
|
148
|
+
|
149
|
+
|
150
|
+
class TextBlock:
    """A cell of word-wrapped text, laid out as a table by ``render_text_blocks``.

    Args:
        text: The cell contents. Existing ANSI color codes are stripped.
        color: Color applied to every line when rendered.
        bold: Whether lines are rendered in bold.
        width: Fixed cell width, also used as the wrapping length. If
            ``None``, the text is not wrapped and the cell is as wide as its
            longest line.
        space: Separator used when re-joining words (see ``wrapped``).
        spaces: Pattern used to split lines into words.
        newlines: Pattern used to split the text into lines.
        too_long_suffix: Suffix added to words truncated to fit ``width``.
        no_sep: If every block of a row sets this, no horizontal separator is
            drawn above that row.
        center: Center lines inside the cell instead of left-aligning them.
    """

    def __init__(
        self,
        text: str,
        color: Color | None = None,
        bold: bool = False,
        width: int | None = None,
        space: str = " ",
        spaces: str | re.Pattern = r" ",
        newlines: str | re.Pattern = r"[\n\r]",
        too_long_suffix: str = "...",
        no_sep: bool = False,
        center: bool = False,
    ) -> None:
        super().__init__()

        # Layout / styling options consumed by ``render_text_blocks``.
        self.width = width
        self.color = color
        self.bold = bold
        self.no_sep = no_sep
        self.center = center

        # Colors are stripped before wrapping so that line lengths only count
        # visible characters.
        self.lines = wrapped(uncolored(text), width, space, spaces, newlines, too_long_suffix)
|
172
|
+
|
173
|
+
|
174
|
+
def render_text_blocks(
    blocks: list[list[TextBlock]],
    newline: str = "\n",
    align_all_blocks: bool = False,
    padding: int = 0,
) -> str:
    """Renders rows of text blocks as a box-drawn table.

    Each inner list is one row. Rows may hold different numbers of blocks
    (unless ``align_all_blocks`` is set), and the separator between two rows
    uses T-junction / cross characters wherever the column borders of either
    row meet it.

    Args:
        blocks: The rows of blocks to render.
        newline: The string used to join the rendered lines.
        align_all_blocks: If set, every column gets the width of its widest
            block. This overwrites ``width`` on the given blocks.
        padding: Extra width added to every cell.

    Returns:
        The rendered table, or an empty string if ``blocks`` is empty.

    Raises:
        ValueError: If ``align_all_blocks`` is set and the rows have different
            numbers of blocks.
    """

    def natural_width(block: TextBlock) -> int:
        # An explicit width wins; otherwise the cell fits its longest line.
        if block.width is not None:
            return block.width
        return max(len(line) for line in block.lines)

    if align_all_blocks:
        if any(len(row) != len(blocks[0]) for row in blocks):
            raise ValueError("All rows must have the same number of blocks in order to align them")
        # All widths are measured before any block is modified.
        column_widths = [max(column) for column in zip(*[[natural_width(b) for b in row] for row in blocks])]
        for row in blocks:
            for block, column_width in zip(row, column_widths):
                block.width = column_width

    def cell_widths(row: list[TextBlock], extra: int = 0) -> list[int]:
        return [natural_width(block) + extra + padding for block in row]

    def boundaries(row: list[TextBlock]) -> list[int]:
        # Column of each "│" after the leading one at column 0. A cell of
        # width w is drawn as " <w chars> │", i.e. it spans w + 3 columns.
        return list(itertools.accumulate(cell_widths(row, 3)))

    def separator(upper: list[TextBlock], lower: list[TextBlock]) -> str:
        # Each mark is (column, belongs_to_lower_row). After sorting, a column
        # shared by both rows appears as an adjacent (col, False), (col, True) pair.
        marks = sorted([(pos, False) for pos in boundaries(upper)] + [(pos, True) for pos in boundaries(lower)])
        final = len(marks) - 1
        out = ["├"]
        cursor = 1
        for idx, (pos, below) in enumerate(marks):
            if idx > 0 and marks[idx - 1][0] == pos:
                continue  # Already drawn together with the previous mark.
            both = idx < final and marks[idx + 1][0] == pos
            last = idx == (final - 1 if both else final)
            out.append("─" * (pos - cursor))
            if both:
                out.append("┤" if last else "┼")
            elif below:
                out.append("┐" if last else "┬")
            else:
                out.append("┘" if last else "┴")
            cursor = pos + 1
        return "".join(out)

    def pad(text: str, width: int, center: bool) -> str:
        slack = width - len(text)
        left = slack // 2 if center else 0
        return " " * left + text + " " * (slack - left)

    def content_lines(row: list[TextBlock]) -> list[str]:
        widths = cell_widths(row)
        height = max(len(block.lines) for block in row)
        rendered = []
        for idx in range(height):
            cells = []
            for block, width in zip(row, widths):
                if idx < len(block.lines):
                    cells.append(colored(pad(block.lines[idx], width, block.center), block.color, bold=block.bold))
                else:
                    # Shorter blocks are filled with blank lines.
                    cells.append(" " * width)
            rendered.append("│ " + " │ ".join(cells) + " │")
        return rendered

    output: list[str] = []
    previous: list[TextBlock] | None = None
    for row in blocks:
        if previous is None:
            output.append("┌─" + "─┬─".join("─" * w for w in cell_widths(row)) + "─┐")
        elif not all(block.no_sep for block in row):
            output.append(separator(previous, row))
        output.extend(content_lines(row))
        previous = row
    if previous is not None:
        output.append("└─" + "─┴─".join("─" * w for w in cell_widths(previous)) + "─┘")

    return newline.join(output)
|
277
|
+
|
278
|
+
|
279
|
+
def format_timedelta(timedelta: datetime.timedelta, short: bool = False) -> str:
    """Formats a delta time to human-readable format.

    Days are shown when non-zero. Hours and minutes are shown when the
    remainder contains at least one full hour or minute. Seconds are always
    shown, e.g. ``"2 days, 1 hour, 5 seconds"``, or with ``short``,
    ``"2d, 1h, 5s"``. Microseconds are dropped. Negative deltas are formatted
    by magnitude and prefixed with ``"-"``.

    Args:
        timedelta: The delta to format
        short: If set, uses a shorter format

    Returns:
        The human-readable time delta
    """
    if timedelta < datetime.timedelta(0):
        # Negative deltas are normalized to (negative days, positive seconds),
        # which would otherwise print as a misleading positive duration.
        return "-" + format_timedelta(-timedelta, short=short)

    def fmt(value: int, unit: str) -> str:
        if short:
            return f"{value}{unit[0]}"
        return f"{value} {unit}" if value == 1 else f"{value} {unit}s"

    parts = []
    if timedelta.days > 0:
        parts.append(fmt(timedelta.days, "day"))

    # divmod gives ">=" semantics: exactly one hour or minute is reported as
    # "1 hour" / "1 minute" rather than "60 minutes" / "60 seconds".
    hours, seconds = divmod(timedelta.seconds, 60 * 60)
    minutes, seconds = divmod(seconds, 60)
    if hours > 0:
        parts.append(fmt(hours, "hour"))
    if minutes > 0:
        parts.append(fmt(minutes, "minute"))
    parts.append(fmt(seconds, "second"))

    return ", ".join(parts)
|
318
|
+
|
319
|
+
|
320
|
+
def format_datetime(dt: datetime.datetime) -> str:
    """Formats a datetime as ``YYYY-mm-dd HH:MM:SS``.

    Args:
        dt: The datetime to format. Sub-second precision and any timezone
            information are not included in the output.

    Returns:
        The human-readable datetime.
    """
    return f"{dt:%Y-%m-%d %H:%M:%S}"
|
330
|
+
|
331
|
+
|
332
|
+
def camelcase_to_snakecase(s: str) -> str:
    """Converts ``CamelCase`` to ``snake_case``.

    An underscore is inserted wherever a lowercase letter or digit is followed
    by an uppercase letter, then the result is lowercased. Runs of capitals
    are not split, so ``"HTTPServer"`` becomes ``"httpserver"``.
    """
    with_underscores = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s)
    return with_underscores.lower()
|
334
|
+
|
335
|
+
|
336
|
+
def snakecase_to_camelcase(s: str) -> str:
    """Converts ``snake_case`` to ``CamelCase``.

    Each underscore-separated word gets its first character uppercased and
    the rest lowercased. ``str.capitalize`` is used instead of ``str.title``
    because ``title`` treats digits as word breaks: ``"conv2d"`` would become
    ``"Conv2D"`` and no longer round-trip through ``camelcase_to_snakecase``.

    Args:
        s: The snake_case string.

    Returns:
        The CamelCase string.
    """
    return "".join(word.capitalize() for word in s.split("_"))
|
338
|
+
|
339
|
+
|
340
|
+
def highlight_exception_message(s: str) -> str:
    """Adds ANSI highlighting to a formatted exception / traceback message.

    Names ending in ``Error`` or ``Exception`` become bold red, names ending
    in ``Warning`` bold yellow, runs of ``^`` bold magenta, and the quoted
    path in ``File "..."`` cyan. The substitutions are applied in that order.
    """
    substitutions = (
        (r"(\w+Error)", r"\033[1;31m\1\033[0m"),
        (r"(\w+Exception)", r"\033[1;31m\1\033[0m"),
        (r"(\w+Warning)", r"\033[1;33m\1\033[0m"),
        (r"\^+", r"\033[1;35m\g<0>\033[0m"),
        (r"File \"(.+?)\"", r'File "\033[36m\1\033[0m"'),
    )
    for pattern, replacement in substitutions:
        s = re.sub(pattern, replacement, s)
    return s
|
347
|
+
|
348
|
+
|
349
|
+
def is_interactive_session() -> bool:
    """Returns whether the process appears to be attached to a user.

    True inside an interactive interpreter (``sys.ps1`` or ``sys.ps2`` is
    defined) or when stdout or stderr is a terminal.
    """
    if hasattr(sys, "ps1") or hasattr(sys, "ps2"):
        return True
    return sys.stdout.isatty() or sys.stderr.isatty()
|
@@ -0,0 +1,40 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: xax
|
3
|
+
Version: 0.0.5
|
4
|
+
Summary: The xax project
|
5
|
+
Home-page: https://github.com/dpshai/xax
|
6
|
+
Author: Benjamin Bolte
|
7
|
+
Requires-Python: >=3.11
|
8
|
+
Description-Content-Type: text/markdown
|
9
|
+
License-File: LICENSE
|
10
|
+
Requires-Dist: jax
|
11
|
+
Requires-Dist: jaxtyping
|
12
|
+
Requires-Dist: equinox
|
13
|
+
Requires-Dist: optax
|
14
|
+
Requires-Dist: dpshdl
|
15
|
+
Requires-Dist: cloudpickle
|
16
|
+
Requires-Dist: pillow
|
17
|
+
Requires-Dist: omegaconf
|
18
|
+
Requires-Dist: gitpython
|
19
|
+
Requires-Dist: tensorboard
|
20
|
+
Requires-Dist: psutil
|
21
|
+
Requires-Dist: requests
|
22
|
+
Provides-Extra: dev
|
23
|
+
Requires-Dist: black; extra == "dev"
|
24
|
+
Requires-Dist: darglint; extra == "dev"
|
25
|
+
Requires-Dist: mypy; extra == "dev"
|
26
|
+
Requires-Dist: ruff; extra == "dev"
|
27
|
+
Requires-Dist: pytest; extra == "dev"
|
28
|
+
Requires-Dist: types-pillow; extra == "dev"
|
29
|
+
Requires-Dist: types-psutil; extra == "dev"
|
30
|
+
Requires-Dist: types-requests; extra == "dev"
|
31
|
+
|
32
|
+
# xax
|
33
|
+
|
34
|
+
JAX library for fast experimentation.
|
35
|
+
|
36
|
+
## Installation
|
37
|
+
|
38
|
+
```bash
|
39
|
+
pip install xax
|
40
|
+
```
|
@@ -0,0 +1,52 @@
|
|
1
|
+
xax/__init__.py,sha256=3OQTnHGYgaux3i9gTYZxfK8F2zS_hK2QqD-G-Z1TfHQ,7623
|
2
|
+
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
|
+
xax/requirements.txt,sha256=DRn2B9d3mAr57-U3IOIrKm2nYz8H3cYgDy6EIC3SsuE,266
|
5
|
+
xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
+
xax/core/conf.py,sha256=hwgc5sJw0YRSegQLLrmIDtscev-H_a2ST1-V6BJ5aec,5915
|
7
|
+
xax/core/state.py,sha256=7lnVSytuhwPfcobPGdjfQ0QxbLgzWQNipKwXchd58QI,2695
|
8
|
+
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
+
xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
10
|
+
xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
|
11
|
+
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
12
|
+
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
+
xax/task/base.py,sha256=n82Sw-kMLr-WZzh0c_vAAQ2b-DHRYs0U8biPRonBxKU,7252
|
14
|
+
xax/task/logger.py,sha256=MAFIgd6yO0pD3gJHfKTwUDcwaM8DZD3AZtFLvrQtlFo,26740
|
15
|
+
xax/task/script.py,sha256=oBGnScYa_X284fCajabPCcbaSEIqR8nO4d40dvMv3NQ,1011
|
16
|
+
xax/task/task.py,sha256=X7TV_gt6C4m_-Il22Uyr5iMm-eh15oH5v1dl96sv1go,1295
|
17
|
+
xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
18
|
+
xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
|
19
|
+
xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
|
20
|
+
xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
|
21
|
+
xax/task/launchers/staged.py,sha256=jYeT9u58CN4ldV-ltJiQXQglEWOnEckHWnHYjfJQaoY,1102
|
22
|
+
xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
|
+
xax/task/loggers/callback.py,sha256=reaRuJs5iB6WWNgh3_tsuz_QPAlBC-5Ed2wCG_6Wj4M,2075
|
24
|
+
xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
|
25
|
+
xax/task/loggers/state.py,sha256=qyb-q8MdagN7BX-DhKucwoc45tIZJrPuvVDVoysTKC4,1576
|
26
|
+
xax/task/loggers/stdout.py,sha256=nxQXkS9JUR38RKsU9qj0dgePKguK0BFa9nl_BdGO8cE,6758
|
27
|
+
xax/task/loggers/tensorboard.py,sha256=DMYRDCQ9c-xHqO4kkZvc1-53PXCf2gX0aRiiAQDtHJ0,7293
|
28
|
+
xax/task/mixins/__init__.py,sha256=NkSAjMN5jpXE6LROIwMzX60z7UsTBpGs624_mNUWquo,745
|
29
|
+
xax/task/mixins/artifacts.py,sha256=G0984WuXII_R13IlJZn9En7iM83ISXKjeVYvn7j4wBs,3754
|
30
|
+
xax/task/mixins/checkpointing.py,sha256=JV91b5xyBUyZIbR3S-5UkBZNoAZYCnWx7Y-ayuU0lHQ,7989
|
31
|
+
xax/task/mixins/cpu_stats.py,sha256=Lqskt1t4usE6UslhANjwB0ZKOYmaC4dm9dnVKa6ERdA,8924
|
32
|
+
xax/task/mixins/data_loader.py,sha256=BPs0sYdctesnhS9nQ1rvT77MzLXznw5E4tAzWT1PpJY,5998
|
33
|
+
xax/task/mixins/gpu_stats.py,sha256=tFTNmtl9iMiLiYJSPg7gHR-ZxOP4P_ynzSmYNIAUoRw,8431
|
34
|
+
xax/task/mixins/logger.py,sha256=6XkjP_YUGY2CiDry0kDm1f9jqzJaLa1bPVYYnGjvSBU,2049
|
35
|
+
xax/task/mixins/process.py,sha256=HQAvEruvvfcS_IThrM4hKhFHZCAN2kFY_vEaZGLeZS8,1428
|
36
|
+
xax/task/mixins/runnable.py,sha256=d5-qyIpmNPtbTzE7qFJGGCPSREEDhX1VApUJPNDWye0,1933
|
37
|
+
xax/task/mixins/step_wrapper.py,sha256=Do4eGgZVuqDX9ZGDxQdfn6pRbUnHjQBAkTF0vnNH31E,1472
|
38
|
+
xax/task/mixins/train.py,sha256=Xeb0N9j-Znz5QnMDCXDGPqUSKMNLJkd8oF8giN45l2U,20099
|
39
|
+
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
40
|
+
xax/utils/experiments.py,sha256=qT3H0fyVH8DN417x7T0Xmz4SKoogW81-EHcZfyktFI8,28300
|
41
|
+
xax/utils/jax.py,sha256=VzEVB766UyH3_cgN6UP0FkCsDuGlYg5KJj8YJS4yYUk,439
|
42
|
+
xax/utils/logging.py,sha256=ST1hp2C2xntVVJBUHwo3YxPK19fBLNvHU2WGO1xqcXA,6418
|
43
|
+
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
44
|
+
xax/utils/tensorboard.py,sha256=XqxUlryFVsb75jE36uLcuoUhSr3nWg_-dzji2h6U_rI,8245
|
45
|
+
xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
|
46
|
+
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
47
|
+
xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
|
48
|
+
xax-0.0.5.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
49
|
+
xax-0.0.5.dist-info/METADATA,sha256=VCiQmbjwZtiuORVyB0dloFTgLWtnK4o3FaolNWvf-A4,937
|
50
|
+
xax-0.0.5.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
51
|
+
xax-0.0.5.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
52
|
+
xax-0.0.5.dist-info/RECORD,,
|
@@ -0,0 +1 @@
|
|
1
|
+
xax
|
examples/mnist.py
DELETED
@@ -1,148 +0,0 @@
|
|
1
|
-
"""MNIST example in Jax."""
|
2
|
-
|
3
|
-
import array
|
4
|
-
import gzip
|
5
|
-
import itertools
|
6
|
-
import os
|
7
|
-
import struct
|
8
|
-
import time
|
9
|
-
import urllib.request
|
10
|
-
from typing import Iterator
|
11
|
-
|
12
|
-
import jax.numpy as jnp
|
13
|
-
import numpy as np
|
14
|
-
import numpy.random as npr
|
15
|
-
from jax import grad, jit, random
|
16
|
-
from jax.example_libraries import optimizers, stax
|
17
|
-
from jax.example_libraries.optimizers import OptimizerState
|
18
|
-
from jax.example_libraries.stax import Dense, LogSoftmax, Relu
|
19
|
-
from jaxtyping import ArrayLike
|
20
|
-
|
21
|
-
_DATA = "/tmp/jax_example_data/"


def _download(url: str, filename: str) -> None:
    """Downloads ``url`` into the ``_DATA`` cache directory unless already present.

    The file is fetched under a temporary name and only moved into place once
    the download finishes. This prevents an interrupted download from being
    mistaken for a cached file on the next run.

    Args:
        url: The URL to fetch.
        filename: Name of the file inside ``_DATA``.
    """
    # exist_ok avoids a race between the existence check and creation.
    os.makedirs(_DATA, exist_ok=True)
    out_file = os.path.join(_DATA, filename)
    if os.path.isfile(out_file):
        return
    tmp_file = out_file + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_file)
        os.replace(tmp_file, out_file)
    finally:
        # Only left behind if the download or the rename failed.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
    print(f"downloaded {url} to {_DATA}")
|
31
|
-
|
32
|
-
|
33
|
-
def _partial_flatten(x: np.ndarray) -> np.ndarray:
|
34
|
-
return np.reshape(x, (x.shape[0], -1))
|
35
|
-
|
36
|
-
|
37
|
-
def _one_hot(x: np.ndarray, k: int, dtype: type = np.float32) -> np.ndarray:
|
38
|
-
return np.array(x[:, None] == np.arange(k), dtype)
|
39
|
-
|
40
|
-
|
41
|
-
def mnist_raw() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Downloads (if needed) and parses the four raw MNIST files.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)`` as uint8
        arrays. Image arrays have shape ``(num_images, rows, cols)``; label
        arrays are 1-D.
    """
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
    train_images_gz = "train-images-idx3-ubyte.gz"
    train_labels_gz = "train-labels-idx1-ubyte.gz"
    test_images_gz = "t10k-images-idx3-ubyte.gz"
    test_labels_gz = "t10k-labels-idx1-ubyte.gz"

    for name in (train_images_gz, train_labels_gz, test_images_gz, test_labels_gz):
        _download(base_url + name, name)

    def read_idx(name: str, header_format: str) -> tuple[tuple[int, ...], np.ndarray]:
        # Each gzip file holds a big-endian header of uint32 fields followed
        # by a uint8 payload.
        with gzip.open(os.path.join(_DATA, name), "rb") as fh:
            header = struct.unpack(header_format, fh.read(struct.calcsize(header_format)))
            payload = np.array(array.array("B", fh.read()), dtype=np.uint8)
        return header, payload

    def read_labels(name: str) -> np.ndarray:
        _, labels = read_idx(name, ">II")
        return labels

    def read_images(name: str) -> np.ndarray:
        (_, num_data, rows, cols), pixels = read_idx(name, ">IIII")
        return pixels.reshape(num_data, rows, cols)

    train_images = read_images(train_images_gz)
    train_labels = read_labels(train_labels_gz)
    test_images = read_images(test_images_gz)
    test_labels = read_labels(test_labels_gz)
    return train_images, train_labels, test_images, test_labels
|
68
|
-
|
69
|
-
|
70
|
-
def mnist(permute_train: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Loads MNIST as flattened images scaled to [0, 1] and one-hot labels.

    Args:
        permute_train: If set, shuffles the training set using a fixed seed (0).

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``.
    """
    raw_train_images, raw_train_labels, raw_test_images, raw_test_labels = mnist_raw()
    scale = np.float32(255.0)

    train_images = _partial_flatten(raw_train_images) / scale
    test_images = _partial_flatten(raw_test_images) / scale
    train_labels = _one_hot(raw_train_labels, 10)
    test_labels = _one_hot(raw_test_labels, 10)

    if not permute_train:
        return train_images, train_labels, test_images, test_labels

    perm = np.random.RandomState(0).permutation(train_images.shape[0])
    return train_images[perm], train_labels[perm], test_images, test_labels
|
84
|
-
|
85
|
-
|
86
|
-
def loss(params: tuple[ArrayLike, ArrayLike], batch: tuple[ArrayLike, ArrayLike]) -> ArrayLike:
    """Mean negative log-likelihood of ``batch`` under the model.

    ``predict`` ends in ``LogSoftmax``, so multiplying by the one-hot targets
    and summing picks out each example's log-probability of the true class.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs)
    per_example = jnp.sum(log_probs * targets, axis=1)
    return -jnp.mean(per_example)
|
90
|
-
|
91
|
-
|
92
|
-
def accuracy(params: tuple[ArrayLike, ArrayLike], batch: tuple[ArrayLike, ArrayLike]) -> ArrayLike:
    """Fraction of examples whose arg-max prediction matches the one-hot target."""
    inputs, targets = batch
    predicted = jnp.argmax(predict(params, inputs), axis=1)
    actual = jnp.argmax(targets, axis=1)
    return jnp.mean(predicted == actual)
|
97
|
-
|
98
|
-
|
99
|
-
# MLP with two 1024-unit ReLU hidden layers and 10 outputs. The final
# LogSoftmax makes ``predict`` return log-probabilities.
init_random_params, predict = stax.serial(
    Dense(1024),
    Relu,
    Dense(1024),
    Relu,
    Dense(10),
    LogSoftmax,
)
|
100
|
-
|
101
|
-
if __name__ == "__main__":
    # Usage: python -m examples.mnist
    rng = random.PRNGKey(0)

    # Optimization hyperparameters.
    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    train_images, train_labels, test_images, test_labels = mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    # A trailing partial batch still counts as a batch.
    num_batches = num_complete_batches + bool(leftover)

    def data_stream() -> Iterator[tuple[np.ndarray, np.ndarray]]:
        """Yields training minibatches forever, reshuffling on every pass."""
        shuffler = npr.RandomState(0)
        while True:
            order = shuffler.permutation(num_train)
            for start in range(0, num_batches * batch_size, batch_size):
                chosen = order[start : start + batch_size]
                yield train_images[chosen], train_labels[chosen]

    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

    @jit
    def update(i: int, opt_state: OptimizerState, batch: tuple[ArrayLike, ArrayLike]) -> OptimizerState:
        """Takes one momentum step on ``batch``; ``i`` is the global step index."""
        gradients = grad(loss)(get_params(opt_state), batch)
        return opt_update(i, gradients, opt_state)

    _, init_params = init_random_params(rng, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        # NOTE(review): JAX dispatches work asynchronously, so this wall-clock
        # time may not include device work still in flight when the loop ends.
        for step in itertools.islice(itercount, num_batches):
            opt_state = update(step, opt_state, next(batches))
        epoch_time = time.time() - start_time

        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
        print(
            f"Epoch {epoch} in {epoch_time:0.2f} sec",
            f"Training set accuracy {train_acc}",
            f"Test set accuracy {test_acc}",
            sep="\n",
        )
|
xax-0.0.1.dist-info/METADATA
DELETED
@@ -1,21 +0,0 @@
|
|
1
|
-
Metadata-Version: 2.1
|
2
|
-
Name: xax
|
3
|
-
Version: 0.0.1
|
4
|
-
Summary: The xax project
|
5
|
-
Home-page: https://github.com/dpshai/xax
|
6
|
-
Author: Benjamin Bolte
|
7
|
-
Requires-Python: >=3.11
|
8
|
-
Description-Content-Type: text/markdown
|
9
|
-
License-File: LICENSE
|
10
|
-
Requires-Dist: jax
|
11
|
-
Requires-Dist: jaxtyping
|
12
|
-
Provides-Extra: dev
|
13
|
-
Requires-Dist: black ; extra == 'dev'
|
14
|
-
Requires-Dist: darglint ; extra == 'dev'
|
15
|
-
Requires-Dist: mypy ; extra == 'dev'
|
16
|
-
Requires-Dist: pytest ; extra == 'dev'
|
17
|
-
Requires-Dist: ruff ; extra == 'dev'
|
18
|
-
|
19
|
-
# xax
|
20
|
-
|
21
|
-
JAX library for fast experimentation.
|
xax-0.0.1.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
|
|
1
|
-
examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
examples/mnist.py,sha256=vIKKlxS153hWM0Z8xM3HzbbF1leWrgRBPzL_fDoTMko,5366
|
3
|
-
xax/__init__.py,sha256=sXLh7g3KC4QCFxcZGBTpG2scR7hmmBsMjq6LqRptkRg,22
|
4
|
-
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
xax-0.0.1.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
6
|
-
xax-0.0.1.dist-info/METADATA,sha256=F9reokda5s_h_HTX5dKfUKRxNKLoAWNs5nnIR-g9kj0,524
|
7
|
-
xax-0.0.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
8
|
-
xax-0.0.1.dist-info/top_level.txt,sha256=rD77ScL4NoinNhAT1cDqumVhFa2AoqoIrgBKbwSsOzY,13
|
9
|
-
xax-0.0.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|