xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. xax/__init__.py +256 -1
  2. xax/core/conf.py +193 -0
  3. xax/core/state.py +81 -0
  4. xax/nn/__init__.py +0 -0
  5. xax/nn/embeddings.py +355 -0
  6. xax/nn/functions.py +77 -0
  7. xax/nn/parallel.py +211 -0
  8. xax/requirements-dev.txt +15 -0
  9. xax/requirements.txt +23 -0
  10. xax/task/__init__.py +0 -0
  11. xax/task/base.py +207 -0
  12. xax/task/launchers/__init__.py +0 -0
  13. xax/task/launchers/base.py +28 -0
  14. xax/task/launchers/cli.py +42 -0
  15. xax/task/launchers/single_process.py +30 -0
  16. xax/task/launchers/staged.py +29 -0
  17. xax/task/logger.py +783 -0
  18. xax/task/loggers/__init__.py +0 -0
  19. xax/task/loggers/callback.py +56 -0
  20. xax/task/loggers/json.py +121 -0
  21. xax/task/loggers/state.py +45 -0
  22. xax/task/loggers/stdout.py +170 -0
  23. xax/task/loggers/tensorboard.py +223 -0
  24. xax/task/mixins/__init__.py +12 -0
  25. xax/task/mixins/artifacts.py +114 -0
  26. xax/task/mixins/checkpointing.py +209 -0
  27. xax/task/mixins/cpu_stats.py +251 -0
  28. xax/task/mixins/data_loader.py +149 -0
  29. xax/task/mixins/gpu_stats.py +257 -0
  30. xax/task/mixins/logger.py +66 -0
  31. xax/task/mixins/process.py +51 -0
  32. xax/task/mixins/runnable.py +63 -0
  33. xax/task/mixins/step_wrapper.py +63 -0
  34. xax/task/mixins/train.py +541 -0
  35. xax/task/script.py +53 -0
  36. xax/task/task.py +65 -0
  37. xax/utils/__init__.py +0 -0
  38. xax/utils/data/__init__.py +0 -0
  39. xax/utils/data/collate.py +206 -0
  40. xax/utils/experiments.py +802 -0
  41. xax/utils/jax.py +14 -0
  42. xax/utils/logging.py +223 -0
  43. xax/utils/numpy.py +47 -0
  44. xax/utils/tensorboard.py +258 -0
  45. xax/utils/text.py +350 -0
  46. xax-0.0.5.dist-info/METADATA +40 -0
  47. xax-0.0.5.dist-info/RECORD +52 -0
  48. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
  49. xax-0.0.5.dist-info/top_level.txt +1 -0
  50. examples/mnist.py +0 -148
  51. xax-0.0.1.dist-info/METADATA +0 -21
  52. xax-0.0.1.dist-info/RECORD +0 -9
  53. xax-0.0.1.dist-info/top_level.txt +0 -2
  54. {examples → xax/core}/__init__.py +0 -0
  55. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/utils/text.py ADDED
@@ -0,0 +1,350 @@
1
+ """Defines helper functions for displaying text in the terminal."""
2
+
3
+ import datetime
4
+ import itertools
5
+ import re
6
+ import sys
7
+ from typing import Literal
8
+
9
# ANSI escape sequences used to style terminal text.
RESET_SEQ = "\033[0m"  # Resets all styling back to the terminal default.
REG_COLOR_SEQ = "\033[%dm"  # %-formatted with a color code from COLOR_INDEX.
BOLD_COLOR_SEQ = "\033[1;%dm"  # Same as REG_COLOR_SEQ, combined with the bold attribute (1).
BOLD_SEQ = "\033[1m"  # Bold only, without changing the color.

# Color names accepted by the helpers in this module; every name is a key of COLOR_INDEX.
Color = Literal[
    "black",
    "red",
    "green",
    "yellow",
    "blue",
    "magenta",
    "cyan",
    "white",
    "grey",
    "light-red",
    "light-green",
    "light-yellow",
    "light-blue",
    "light-magenta",
    "light-cyan",
]

# Numeric ANSI SGR color code for each Color name, substituted into REG_COLOR_SEQ /
# BOLD_COLOR_SEQ: 30-37 for the basic colors, 90-96 for "grey" and the "light-*" variants.
COLOR_INDEX: dict[Color, int] = {
    "black": 30,
    "red": 31,
    "green": 32,
    "yellow": 33,
    "blue": 34,
    "magenta": 35,
    "cyan": 36,
    "white": 37,
    "grey": 90,
    "light-red": 91,
    "light-green": 92,
    "light-yellow": 93,
    "light-blue": 94,
    "light-magenta": 95,
    "light-cyan": 96,
}
49
+
50
+
51
def color_parts(color: Color, bold: bool = False) -> tuple[str, str]:
    """Returns the ``(start, end)`` escape sequences that wrap text in ``color``.

    Args:
        color: The color name to look up in ``COLOR_INDEX``.
        bold: Whether the start sequence also enables bold.

    Returns:
        The opening color sequence and the reset sequence.
    """
    template = BOLD_COLOR_SEQ if bold else REG_COLOR_SEQ
    code = COLOR_INDEX[color]
    return template % code, RESET_SEQ
55
+
56
+
57
def uncolored(s: str) -> str:
    """Strips ANSI SGR escape sequences (e.g. colors, bold) from ``s``."""
    pieces = re.split(r"\033\[[\d;]+m", s)
    return "".join(pieces)
59
+
60
+
61
+ def colored(s: str, color: Color | None = None, bold: bool = False) -> str:
62
+ if color is None:
63
+ return s
64
+ start, end = color_parts(color, bold=bold)
65
+ return start + s + end
66
+
67
+
68
+ def wrapped(
69
+ s: str,
70
+ length: int | None = None,
71
+ space: str = " ",
72
+ spaces: str | re.Pattern = r" ",
73
+ newlines: str | re.Pattern = r"[\n\r]",
74
+ too_long_suffix: str = "...",
75
+ ) -> list[str]:
76
+ strings = []
77
+ lines = re.split(newlines, s.strip(), flags=re.MULTILINE | re.UNICODE)
78
+ for line in lines:
79
+ cur_string = []
80
+ cur_length = 0
81
+ for part in re.split(spaces, line.strip(), flags=re.MULTILINE | re.UNICODE):
82
+ if length is None:
83
+ cur_string.append(part)
84
+ cur_length += len(space) + len(part)
85
+ else:
86
+ if len(part) > length:
87
+ part = part[: length - len(too_long_suffix)] + too_long_suffix
88
+ if cur_length + len(part) > length:
89
+ strings.append(space.join(cur_string))
90
+ cur_string = [part]
91
+ cur_length = len(part)
92
+ else:
93
+ cur_string.append(part)
94
+ cur_length += len(space) + len(part)
95
+ if cur_length > 0:
96
+ strings.append(space.join(cur_string))
97
+ return strings
98
+
99
+
100
def outlined(
    s: str,
    inner: Color | None = None,
    side: Color | None = None,
    bold: bool = False,
    max_length: int | None = None,
    space: str = " ",
    spaces: str | re.Pattern = r" ",
    newlines: str | re.Pattern = r"[\n\r]",
) -> str:
    """Draws a box-drawing border around ``s``.

    Existing ANSI colors in ``s`` are stripped. The text is wrapped with
    ``wrapped`` and every row is left-justified to the widest row.

    Args:
        s: The text to put in the box.
        inner: Color of the text inside the box.
        side: Color of the border characters.
        bold: Whether the inner text is bold.
        max_length: Maximum width of the text area, or ``None`` for no wrapping.
        space: Word separator used when re-joining wrapped words.
        spaces: Pattern separating words in ``s``.
        newlines: Pattern separating lines in ``s``.

    Returns:
        The boxed text, with rows joined by newlines.
    """
    rows = wrapped(uncolored(s), max_length, space, spaces, newlines)
    inner_width = max(map(len, rows))
    bar = colored("│", side)
    body = [f"{bar} {colored(row.ljust(inner_width), inner, bold=bold)} {bar}" for row in rows]
    # The horizontal rules span the text plus one space of margin on each side.
    top = colored("┌" + "─" * (inner_width + 2) + "┐", side)
    bottom = colored("└" + "─" * (inner_width + 2) + "┘", side)
    return "\n".join([top, *body, bottom])
118
+
119
+
120
def _show_message(s: str, important: bool, inner: Color, side: Color) -> None:
    """Writes ``s`` to stdout: boxed and bold if ``important``, otherwise just colored."""
    if important:
        text = outlined(s, inner=inner, side=side, bold=True)
    else:
        text = colored(s, inner, bold=False)
    sys.stdout.write(text)
    sys.stdout.write("\n")
    sys.stdout.flush()


def show_info(s: str, important: bool = False) -> None:
    """Prints an informational message in cyan to stdout."""
    _show_message(s, important, inner="light-cyan", side="cyan")


def show_error(s: str, important: bool = False) -> None:
    """Prints an error message in red to stdout."""
    _show_message(s, important, inner="light-red", side="red")


def show_warning(s: str, important: bool = False) -> None:
    """Prints a warning message in yellow to stdout."""
    _show_message(s, important, inner="light-yellow", side="yellow")
148
+
149
+
150
class TextBlock:
    """One cell of text for ``render_text_blocks``.

    The text is stripped of ANSI color codes and pre-wrapped into ``lines``
    when the block is created. The styling attributes are applied later, at
    render time.

    Args:
        text: The text to show; any existing ANSI color codes are removed.
        color: Color applied to every line when rendered, or ``None``.
        bold: Whether to render the text in bold.
        width: Fixed cell width, also used for wrapping; ``None`` sizes the
            cell to its longest line.
        space: Separator used when re-joining words.
        spaces: Pattern separating words in ``text``.
        newlines: Pattern separating lines in ``text``.
        too_long_suffix: Suffix marking a word truncated to fit ``width``.
        no_sep: If every block in a row sets this, no separator line is drawn
            above that row.
        center: Whether to center each line within the cell.
    """

    def __init__(
        self,
        text: str,
        color: Color | None = None,
        bold: bool = False,
        width: int | None = None,
        space: str = " ",
        spaces: str | re.Pattern = r" ",
        newlines: str | re.Pattern = r"[\n\r]",
        too_long_suffix: str = "...",
        no_sep: bool = False,
        center: bool = False,
    ) -> None:
        super().__init__()

        plain_text = uncolored(text)
        self.width = width
        self.lines = wrapped(plain_text, width, space, spaces, newlines, too_long_suffix)
        self.color, self.bold = color, bold
        self.no_sep, self.center = no_sep, center
172
+
173
+
174
def render_text_blocks(
    blocks: list[list[TextBlock]],
    newline: str = "\n",
    align_all_blocks: bool = False,
    padding: int = 0,
) -> str:
    """Renders a collection of blocks into a single string.

    Each inner list is one row of cells. Rows are stacked vertically inside a
    box-drawing border, with separator lines between rows.

    Args:
        blocks: The blocks to render, as a list of rows.
        newline: The string to use as a newline separator.
        align_all_blocks: If set, aligns the widths for all blocks. This
            overwrites ``width`` on every block in ``blocks`` in place.
        padding: The amount of padding to add to each block's width.

    Returns:
        The rendered blocks.

    Raises:
        ValueError: If ``align_all_blocks`` is set and the rows do not all
            contain the same number of blocks.
    """
    if align_all_blocks:
        if any(len(row) != len(blocks[0]) for row in blocks):
            raise ValueError("All rows must have the same number of blocks in order to align them")
        # Width of every block (explicit width, else longest line), then the maximum per column.
        widths = [[max(len(line) for line in i.lines) if i.width is None else i.width for i in r] for r in blocks]
        row_widths = [max(i) for i in zip(*widths)]
        for row in blocks:
            for i, block in enumerate(row):
                block.width = row_widths[i]

    def get_widths(row: list[TextBlock], n: int = 0) -> list[int]:
        # Cell width of each block (explicit width, else longest line), plus
        # ``n`` extra columns and the global ``padding``.
        return [
            (max(len(line) for line in block.lines) if block.width is None else block.width) + n + padding
            for block in row
        ]

    def get_acc_widths(row: list[TextBlock], n: int = 0) -> list[int]:
        # Running totals of the widths. With n=3 (the " │ " between cells), these are
        # the column offsets of each cell's right-hand "│", with the left border at column 0.
        return list(itertools.accumulate(get_widths(row, n)))

    def get_height(row: list[TextBlock]) -> int:
        # A row is as tall as its tallest block.
        return max(len(block.lines) for block in row)

    def pad(s: str, width: int, center: bool) -> str:
        # Pads ``s`` with spaces to ``width``; when centering, an odd leftover space goes on the right.
        swidth = len(s)
        if center:
            lpad, rpad = (width - swidth) // 2, (width - swidth + 1) // 2
        else:
            lpad, rpad = 0, width - swidth
        return " " * lpad + s + " " * rpad

    lines = []
    prev_row: list[TextBlock] | None = None
    for row in blocks:
        if prev_row is None:
            # Top border, with a "┬" above each boundary between cells.
            lines += ["┌─" + "─┬─".join(["─" * width for width in get_widths(row)]) + "─┐"]
        elif not all(block.no_sep for block in row):
            # Separator between ``prev_row`` (above) and ``row`` (below). The border
            # columns of both rows are merged into one sorted list of
            # (column, is_out): is_out=False marks a border of the row above,
            # True a border of the row below. False sorts first at equal columns.
            ins, outs = get_acc_widths(prev_row, 3), get_acc_widths(row, 3)
            segs = sorted([(i, False) for i in ins] + [(i, True) for i in outs])
            line = ["├"]

            c = 1  # Next column of the separator line still to be filled.
            for i, (s, is_out) in enumerate(segs):
                # A column shared by both rows appears twice; handle it once.
                if i > 0 and segs[i - 1][0] == s:
                    continue
                is_in_out = i < len(segs) - 1 and segs[i + 1][0] == s
                is_last = i == len(segs) - 2 if is_in_out else i == len(segs) - 1

                # ``line`` is a list, so += extends it with one "─" per column.
                line += "─" * (s - c)
                if is_last:
                    # Right edge: joins both rows, only the row below, or only the row above.
                    if is_in_out:
                        line += "┤"
                    elif is_out:
                        line += "┐"
                    else:
                        line += "┘"
                else:  # noqa: PLR5501
                    # Inner junction: both rows, only the row below, or only the row above.
                    if is_in_out:
                        line += "┼"
                    elif is_out:
                        line += "┬"
                    else:
                        line += "┴"
                c = s + 1

            lines += ["".join(line)]

        # Text lines of the row; blocks shorter than the row are filled with blanks.
        for i in range(get_height(row)):
            lines += [
                "│ "
                + " │ ".join(
                    [
                        (
                            " " * width
                            if i >= len(block.lines)
                            else colored(pad(block.lines[i], width, block.center), block.color, bold=block.bold)
                        )
                        for block, width in zip(row, get_widths(row))
                    ]
                )
                + " │"
            ]

        prev_row = row
    if prev_row is not None:
        # Bottom border, with a "┴" below each boundary between cells of the last row.
        lines += ["└─" + "─┴─".join(["─" * width for width in get_widths(prev_row)]) + "─┘"]

    return newline.join(lines)
277
+
278
+
279
def format_timedelta(timedelta: datetime.timedelta, short: bool = False) -> str:
    """Formats a delta time to human-readable format.

    Whole days, hours and minutes are included only when non-zero. Seconds are
    always included, and sub-second precision is dropped.

    Args:
        timedelta: The delta to format. A negative delta is rendered as the
            formatted absolute value prefixed with ``-``.
        short: If set, uses a shorter format (``1d, 2h, 3m, 4s``) instead of
            ``1 day, 2 hours, 3 minutes, 4 seconds``.

    Returns:
        The human-readable time delta.
    """
    if timedelta < datetime.timedelta(0):
        # timedelta normalizes negative values to negative days plus positive
        # seconds, so -5s would otherwise render as "23 hours, 59 minutes, 55 seconds".
        return "-" + format_timedelta(-timedelta, short=short)

    def fmt(value: int, abbrev: str, unit: str) -> str:
        if short:
            return f"{value}{abbrev}"
        return f"{value} {unit}" if value == 1 else f"{value} {unit}s"

    # Split the intra-day seconds into whole hours and minutes. divmod makes
    # exactly 3600 s render as "1 hour" and exactly 60 s as "1 minute".
    hours, remainder = divmod(timedelta.seconds, 60 * 60)
    minutes, seconds = divmod(remainder, 60)

    parts = []
    if timedelta.days > 0:
        parts.append(fmt(timedelta.days, "d", "day"))
    if hours > 0:
        parts.append(fmt(hours, "h", "hour"))
    if minutes > 0:
        parts.append(fmt(minutes, "m", "minute"))
    parts.append(fmt(seconds, "s", "second"))
    return ", ".join(parts)
318
+
319
+
320
def format_datetime(dt: datetime.datetime) -> str:
    """Renders ``dt`` as ``YYYY-MM-DD HH:MM:SS`` (microseconds and timezone are omitted).

    Args:
        dt: The datetime to render.

    Returns:
        The formatted timestamp string.
    """
    return f"{dt:%Y-%m-%d %H:%M:%S}"
330
+
331
+
332
def camelcase_to_snakecase(s: str) -> str:
    """Converts ``CamelCase`` to ``snake_case``.

    An underscore is inserted wherever an uppercase letter directly follows a
    lowercase letter or digit, and the result is lowercased. Runs of capitals
    are therefore not split (``"HTTPServer"`` -> ``"httpserver"``).
    """
    with_breaks = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", s)
    return with_breaks.lower()
334
+
335
+
336
def snakecase_to_camelcase(s: str) -> str:
    """Converts ``snake_case`` to ``CamelCase`` by title-casing each ``_``-separated word."""
    words = s.split("_")
    return "".join(map(str.title, words))
338
+
339
+
340
def highlight_exception_message(s: str) -> str:
    """Adds ANSI highlighting to a traceback or exception message.

    The following parts are highlighted, in this order:

    - ``*Error`` and ``*Exception`` names in bold red.
    - ``*Warning`` names in bold yellow.
    - Runs of ``^`` markers in bold magenta.
    - Quoted ``File "..."`` paths in cyan.
    """
    substitutions = (
        (r"(\w+Error)", r"\033[1;31m\1\033[0m"),
        (r"(\w+Exception)", r"\033[1;31m\1\033[0m"),
        (r"(\w+Warning)", r"\033[1;33m\1\033[0m"),
        (r"\^+", r"\033[1;35m\g<0>\033[0m"),
        (r"File \"(.+?)\"", r'File "\033[36m\1\033[0m"'),
    )
    for pattern, replacement in substitutions:
        s = re.sub(pattern, replacement, s)
    return s
347
+
348
+
349
def is_interactive_session() -> bool:
    """Returns True in an interactive interpreter or when stdout/stderr is a terminal."""
    # ``sys.ps1``/``sys.ps2`` only exist while the interpreter is in interactive mode.
    if hasattr(sys, "ps1") or hasattr(sys, "ps2"):
        return True
    return sys.stdout.isatty() or sys.stderr.isatty()
@@ -0,0 +1,40 @@
1
+ Metadata-Version: 2.1
2
+ Name: xax
3
+ Version: 0.0.5
4
+ Summary: The xax project
5
+ Home-page: https://github.com/dpshai/xax
6
+ Author: Benjamin Bolte
7
+ Requires-Python: >=3.11
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: jax
11
+ Requires-Dist: jaxtyping
12
+ Requires-Dist: equinox
13
+ Requires-Dist: optax
14
+ Requires-Dist: dpshdl
15
+ Requires-Dist: cloudpickle
16
+ Requires-Dist: pillow
17
+ Requires-Dist: omegaconf
18
+ Requires-Dist: gitpython
19
+ Requires-Dist: tensorboard
20
+ Requires-Dist: psutil
21
+ Requires-Dist: requests
22
+ Provides-Extra: dev
23
+ Requires-Dist: black; extra == "dev"
24
+ Requires-Dist: darglint; extra == "dev"
25
+ Requires-Dist: mypy; extra == "dev"
26
+ Requires-Dist: ruff; extra == "dev"
27
+ Requires-Dist: pytest; extra == "dev"
28
+ Requires-Dist: types-pillow; extra == "dev"
29
+ Requires-Dist: types-psutil; extra == "dev"
30
+ Requires-Dist: types-requests; extra == "dev"
31
+
32
+ # xax
33
+
34
+ JAX library for fast experimentation.
35
+
36
+ ## Installation
37
+
38
+ ```bash
39
+ pip install xax
40
+ ```
@@ -0,0 +1,52 @@
1
+ xax/__init__.py,sha256=3OQTnHGYgaux3i9gTYZxfK8F2zS_hK2QqD-G-Z1TfHQ,7623
2
+ xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
+ xax/requirements.txt,sha256=DRn2B9d3mAr57-U3IOIrKm2nYz8H3cYgDy6EIC3SsuE,266
5
+ xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ xax/core/conf.py,sha256=hwgc5sJw0YRSegQLLrmIDtscev-H_a2ST1-V6BJ5aec,5915
7
+ xax/core/state.py,sha256=7lnVSytuhwPfcobPGdjfQ0QxbLgzWQNipKwXchd58QI,2695
8
+ xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
+ xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
11
+ xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
12
+ xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
+ xax/task/base.py,sha256=n82Sw-kMLr-WZzh0c_vAAQ2b-DHRYs0U8biPRonBxKU,7252
14
+ xax/task/logger.py,sha256=MAFIgd6yO0pD3gJHfKTwUDcwaM8DZD3AZtFLvrQtlFo,26740
15
+ xax/task/script.py,sha256=oBGnScYa_X284fCajabPCcbaSEIqR8nO4d40dvMv3NQ,1011
16
+ xax/task/task.py,sha256=X7TV_gt6C4m_-Il22Uyr5iMm-eh15oH5v1dl96sv1go,1295
17
+ xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
+ xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
19
+ xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
20
+ xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
21
+ xax/task/launchers/staged.py,sha256=jYeT9u58CN4ldV-ltJiQXQglEWOnEckHWnHYjfJQaoY,1102
22
+ xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
+ xax/task/loggers/callback.py,sha256=reaRuJs5iB6WWNgh3_tsuz_QPAlBC-5Ed2wCG_6Wj4M,2075
24
+ xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
25
+ xax/task/loggers/state.py,sha256=qyb-q8MdagN7BX-DhKucwoc45tIZJrPuvVDVoysTKC4,1576
26
+ xax/task/loggers/stdout.py,sha256=nxQXkS9JUR38RKsU9qj0dgePKguK0BFa9nl_BdGO8cE,6758
27
+ xax/task/loggers/tensorboard.py,sha256=DMYRDCQ9c-xHqO4kkZvc1-53PXCf2gX0aRiiAQDtHJ0,7293
28
+ xax/task/mixins/__init__.py,sha256=NkSAjMN5jpXE6LROIwMzX60z7UsTBpGs624_mNUWquo,745
29
+ xax/task/mixins/artifacts.py,sha256=G0984WuXII_R13IlJZn9En7iM83ISXKjeVYvn7j4wBs,3754
30
+ xax/task/mixins/checkpointing.py,sha256=JV91b5xyBUyZIbR3S-5UkBZNoAZYCnWx7Y-ayuU0lHQ,7989
31
+ xax/task/mixins/cpu_stats.py,sha256=Lqskt1t4usE6UslhANjwB0ZKOYmaC4dm9dnVKa6ERdA,8924
32
+ xax/task/mixins/data_loader.py,sha256=BPs0sYdctesnhS9nQ1rvT77MzLXznw5E4tAzWT1PpJY,5998
33
+ xax/task/mixins/gpu_stats.py,sha256=tFTNmtl9iMiLiYJSPg7gHR-ZxOP4P_ynzSmYNIAUoRw,8431
34
+ xax/task/mixins/logger.py,sha256=6XkjP_YUGY2CiDry0kDm1f9jqzJaLa1bPVYYnGjvSBU,2049
35
+ xax/task/mixins/process.py,sha256=HQAvEruvvfcS_IThrM4hKhFHZCAN2kFY_vEaZGLeZS8,1428
36
+ xax/task/mixins/runnable.py,sha256=d5-qyIpmNPtbTzE7qFJGGCPSREEDhX1VApUJPNDWye0,1933
37
+ xax/task/mixins/step_wrapper.py,sha256=Do4eGgZVuqDX9ZGDxQdfn6pRbUnHjQBAkTF0vnNH31E,1472
38
+ xax/task/mixins/train.py,sha256=Xeb0N9j-Znz5QnMDCXDGPqUSKMNLJkd8oF8giN45l2U,20099
39
+ xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
+ xax/utils/experiments.py,sha256=qT3H0fyVH8DN417x7T0Xmz4SKoogW81-EHcZfyktFI8,28300
41
+ xax/utils/jax.py,sha256=VzEVB766UyH3_cgN6UP0FkCsDuGlYg5KJj8YJS4yYUk,439
42
+ xax/utils/logging.py,sha256=ST1hp2C2xntVVJBUHwo3YxPK19fBLNvHU2WGO1xqcXA,6418
43
+ xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
44
+ xax/utils/tensorboard.py,sha256=XqxUlryFVsb75jE36uLcuoUhSr3nWg_-dzji2h6U_rI,8245
45
+ xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
46
+ xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
47
+ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
48
+ xax-0.0.5.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
49
+ xax-0.0.5.dist-info/METADATA,sha256=VCiQmbjwZtiuORVyB0dloFTgLWtnK4o3FaolNWvf-A4,937
50
+ xax-0.0.5.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
51
+ xax-0.0.5.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
52
+ xax-0.0.5.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.42.0)
2
+ Generator: setuptools (75.6.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -0,0 +1 @@
1
+ xax
examples/mnist.py DELETED
@@ -1,148 +0,0 @@
1
- """MNIST example in Jax."""
2
-
3
- import array
4
- import gzip
5
- import itertools
6
- import os
7
- import struct
8
- import time
9
- import urllib.request
10
- from typing import Iterator
11
-
12
- import jax.numpy as jnp
13
- import numpy as np
14
- import numpy.random as npr
15
- from jax import grad, jit, random
16
- from jax.example_libraries import optimizers, stax
17
- from jax.example_libraries.optimizers import OptimizerState
18
- from jax.example_libraries.stax import Dense, LogSoftmax, Relu
19
- from jaxtyping import ArrayLike
20
-
21
# Cache directory for the downloaded MNIST archives; created on demand by _download.
_DATA = "/tmp/jax_example_data/"
22
-
23
-
24
def _download(url: str, filename: str) -> None:
    """Downloads ``url`` to ``_DATA/filename`` unless that file already exists.

    Args:
        url: The URL to fetch.
        filename: Name of the file to create inside ``_DATA``.
    """
    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(_DATA, exist_ok=True)
    out_file = os.path.join(_DATA, filename)
    if os.path.isfile(out_file):
        return
    # Download to a temporary name and rename only on success. An interrupted
    # download then cannot leave a truncated file that later runs treat as complete.
    tmp_file = out_file + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_file)
        os.replace(tmp_file, out_file)
    finally:
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
    print(f"downloaded {url} to {_DATA}")
31
-
32
-
33
- def _partial_flatten(x: np.ndarray) -> np.ndarray:
34
- return np.reshape(x, (x.shape[0], -1))
35
-
36
-
37
- def _one_hot(x: np.ndarray, k: int, dtype: type = np.float32) -> np.ndarray:
38
- return np.array(x[:, None] == np.arange(k), dtype)
39
-
40
-
41
def mnist_raw() -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Downloads (if needed) and parses the gzipped MNIST IDX files.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)`` as uint8
        arrays. Images are reshaped to ``(count, rows, cols)`` using the sizes
        read from each file's header.
    """
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
    train_images_file = "train-images-idx3-ubyte.gz"
    train_labels_file = "train-labels-idx1-ubyte.gz"
    test_images_file = "t10k-images-idx3-ubyte.gz"
    test_labels_file = "t10k-labels-idx1-ubyte.gz"

    def read_labels(path: str) -> np.ndarray:
        with gzip.open(path, "rb") as stream:
            struct.unpack(">II", stream.read(8))  # Header fields are not needed.
            payload = array.array("B", stream.read())
        return np.array(payload, dtype=np.uint8)

    def read_images(path: str) -> np.ndarray:
        with gzip.open(path, "rb") as stream:
            _, count, height, width = struct.unpack(">IIII", stream.read(16))
            payload = array.array("B", stream.read())
        return np.array(payload, dtype=np.uint8).reshape(count, height, width)

    for name in (train_images_file, train_labels_file, test_images_file, test_labels_file):
        _download(base_url + name, name)

    return (
        read_images(os.path.join(_DATA, train_images_file)),
        read_labels(os.path.join(_DATA, train_labels_file)),
        read_images(os.path.join(_DATA, test_images_file)),
        read_labels(os.path.join(_DATA, test_labels_file)),
    )
68
-
69
-
70
def mnist(permute_train: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Loads MNIST with flattened images scaled to [0, 1] and one-hot labels.

    Args:
        permute_train: If set, shuffles the training set with a fixed seed (0).

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``.
    """
    raw_train_x, raw_train_y, raw_test_x, raw_test_y = mnist_raw()

    scale = np.float32(255.0)
    train_x = _partial_flatten(raw_train_x) / scale
    test_x = _partial_flatten(raw_test_x) / scale
    train_y = _one_hot(raw_train_y, 10)
    test_y = _one_hot(raw_test_y, 10)

    if permute_train:
        order = np.random.RandomState(0).permutation(train_x.shape[0])
        train_x, train_y = train_x[order], train_y[order]

    return train_x, train_y, test_x, test_y
84
-
85
-
86
def loss(params: tuple[ArrayLike, ArrayLike], batch: tuple[ArrayLike, ArrayLike]) -> ArrayLike:
    """Mean negative log-likelihood of the one-hot ``targets`` under ``predict``."""
    inputs, targets = batch
    log_probs = predict(params, inputs)
    per_example = jnp.sum(log_probs * targets, axis=1)
    return -jnp.mean(per_example)
90
-
91
-
92
def accuracy(params: tuple[ArrayLike, ArrayLike], batch: tuple[ArrayLike, ArrayLike]) -> ArrayLike:
    """Fraction of examples whose arg-max prediction matches the one-hot target."""
    inputs, targets = batch
    logits = predict(params, inputs)
    hits = jnp.argmax(logits, axis=1) == jnp.argmax(targets, axis=1)
    return jnp.mean(hits)
97
-
98
-
99
# A small MLP: two 1024-unit ReLU hidden layers and a 10-way LogSoftmax output.
# ``init_random_params`` builds the parameters; ``predict(params, inputs)`` returns log-probabilities.
init_random_params, predict = stax.serial(Dense(1024), Relu, Dense(1024), Relu, Dense(10), LogSoftmax)
100
-
101
if __name__ == "__main__":
    # python -m examples.mnist
    # Trains the MLP above on MNIST with momentum SGD and reports accuracy after each epoch.
    rng = random.PRNGKey(0)

    step_size = 0.001
    num_epochs = 10
    batch_size = 128
    momentum_mass = 0.9

    train_images, train_labels, test_images, test_labels = mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    # A final partial batch counts as one more batch.
    num_batches = num_complete_batches + bool(leftover)

    def data_stream() -> Iterator[tuple[np.ndarray, np.ndarray]]:
        """Yields shuffled minibatches forever, reshuffling on every pass over the data."""
        rng = npr.RandomState(0)  # Fixed-seed NumPy RNG; shadows the JAX key above.
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

    @jit
    def update(i: int, opt_state: OptimizerState, batch: tuple[ArrayLike, ArrayLike]) -> OptimizerState:
        """Applies one optimizer step using the gradient of ``loss`` on ``batch``."""
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    # Input shape for flattened 28x28 images.
    _, init_params = init_random_params(rng, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()  # Global step counter passed to the optimizer.

    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_time = time.time() - start_time

        params = get_params(opt_state)
        # Evaluate on the full training and test sets.
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
        print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
        print(f"Training set accuracy {train_acc}")
        print(f"Test set accuracy {test_acc}")
@@ -1,21 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: xax
3
- Version: 0.0.1
4
- Summary: The xax project
5
- Home-page: https://github.com/dpshai/xax
6
- Author: Benjamin Bolte
7
- Requires-Python: >=3.11
8
- Description-Content-Type: text/markdown
9
- License-File: LICENSE
10
- Requires-Dist: jax
11
- Requires-Dist: jaxtyping
12
- Provides-Extra: dev
13
- Requires-Dist: black ; extra == 'dev'
14
- Requires-Dist: darglint ; extra == 'dev'
15
- Requires-Dist: mypy ; extra == 'dev'
16
- Requires-Dist: pytest ; extra == 'dev'
17
- Requires-Dist: ruff ; extra == 'dev'
18
-
19
- # xax
20
-
21
- JAX library for fast experimentation.
@@ -1,9 +0,0 @@
1
- examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- examples/mnist.py,sha256=vIKKlxS153hWM0Z8xM3HzbbF1leWrgRBPzL_fDoTMko,5366
3
- xax/__init__.py,sha256=sXLh7g3KC4QCFxcZGBTpG2scR7hmmBsMjq6LqRptkRg,22
4
- xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- xax-0.0.1.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
6
- xax-0.0.1.dist-info/METADATA,sha256=F9reokda5s_h_HTX5dKfUKRxNKLoAWNs5nnIR-g9kj0,524
7
- xax-0.0.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
8
- xax-0.0.1.dist-info/top_level.txt,sha256=rD77ScL4NoinNhAT1cDqumVhFa2AoqoIrgBKbwSsOzY,13
9
- xax-0.0.1.dist-info/RECORD,,
@@ -1,2 +0,0 @@
1
- examples
2
- xax
File without changes
File without changes