xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. xax/__init__.py +256 -1
  2. xax/core/conf.py +193 -0
  3. xax/core/state.py +81 -0
  4. xax/nn/__init__.py +0 -0
  5. xax/nn/embeddings.py +355 -0
  6. xax/nn/functions.py +77 -0
  7. xax/nn/parallel.py +211 -0
  8. xax/requirements-dev.txt +15 -0
  9. xax/requirements.txt +23 -0
  10. xax/task/__init__.py +0 -0
  11. xax/task/base.py +207 -0
  12. xax/task/launchers/__init__.py +0 -0
  13. xax/task/launchers/base.py +28 -0
  14. xax/task/launchers/cli.py +42 -0
  15. xax/task/launchers/single_process.py +30 -0
  16. xax/task/launchers/staged.py +29 -0
  17. xax/task/logger.py +783 -0
  18. xax/task/loggers/__init__.py +0 -0
  19. xax/task/loggers/callback.py +56 -0
  20. xax/task/loggers/json.py +121 -0
  21. xax/task/loggers/state.py +45 -0
  22. xax/task/loggers/stdout.py +170 -0
  23. xax/task/loggers/tensorboard.py +223 -0
  24. xax/task/mixins/__init__.py +12 -0
  25. xax/task/mixins/artifacts.py +114 -0
  26. xax/task/mixins/checkpointing.py +209 -0
  27. xax/task/mixins/cpu_stats.py +251 -0
  28. xax/task/mixins/data_loader.py +149 -0
  29. xax/task/mixins/gpu_stats.py +257 -0
  30. xax/task/mixins/logger.py +66 -0
  31. xax/task/mixins/process.py +51 -0
  32. xax/task/mixins/runnable.py +63 -0
  33. xax/task/mixins/step_wrapper.py +63 -0
  34. xax/task/mixins/train.py +541 -0
  35. xax/task/script.py +53 -0
  36. xax/task/task.py +65 -0
  37. xax/utils/__init__.py +0 -0
  38. xax/utils/data/__init__.py +0 -0
  39. xax/utils/data/collate.py +206 -0
  40. xax/utils/experiments.py +802 -0
  41. xax/utils/jax.py +14 -0
  42. xax/utils/logging.py +223 -0
  43. xax/utils/numpy.py +47 -0
  44. xax/utils/tensorboard.py +258 -0
  45. xax/utils/text.py +350 -0
  46. xax-0.0.5.dist-info/METADATA +40 -0
  47. xax-0.0.5.dist-info/RECORD +52 -0
  48. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
  49. xax-0.0.5.dist-info/top_level.txt +1 -0
  50. examples/mnist.py +0 -148
  51. xax-0.0.1.dist-info/METADATA +0 -21
  52. xax-0.0.1.dist-info/RECORD +0 -9
  53. xax-0.0.1.dist-info/top_level.txt +0 -2
  54. {examples → xax/core}/__init__.py +0 -0
  55. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/utils/jax.py ADDED
@@ -0,0 +1,14 @@
1
+ """Defines some utility functions for interfacing with Jax."""
2
+
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+
6
+ Number = int | float | np.ndarray | jnp.ndarray  # Type alias: a Python number or a NumPy / Jax array.
7
+
8
+
9
+ def as_float(value: int | float | np.ndarray | jnp.ndarray) -> float:
10
+ if isinstance(value, (int, float)):
11
+ return float(value)
12
+ if isinstance(value, (np.ndarray, jnp.ndarray)):
13
+ return float(value.item())
14
+ raise TypeError(f"Unexpected type: {type(value)}")
xax/utils/logging.py ADDED
@@ -0,0 +1,223 @@
1
+ """Logging utilities."""
2
+
3
+ import logging
4
+ import math
5
+ import socket
6
+ import sys
7
+
8
+ from omegaconf import OmegaConf
9
+
10
+ from xax.core.conf import load_user_config
11
+ from xax.utils.text import Color, color_parts, colored
12
+
13
# Custom log levels, expressed as offsets from the standard `logging` levels.

# Variants of INFO and DEBUG which are shown on all ranks.
LOG_INFO_ALL: int = logging.INFO + 1
LOG_DEBUG_ALL: int = logging.DEBUG + 1

# Transient message.
LOG_PING: int = logging.INFO + 2

# Persistent status message.
LOG_STATUS: int = logging.INFO + 3

# Reserved for the error summary.
LOG_ERROR_SUMMARY: int = logging.INFO + 4
25
+
26
+
27
class RankFilter(logging.Filter):
    """Logging filter which decides, per rank, which records are emitted.

    Records at the error-summary level are always dropped. On rank zero (or
    when no rank is set) everything else passes. On other ranks only the
    levels collected in ``log_all_ranks`` pass.
    """

    def __init__(self, *, rank: int | None = None) -> None:
        """Logging filter which filters out INFO logs on non-zero ranks.

        Args:
            rank: The current rank
        """
        super().__init__()

        self.rank = rank

        # Register names for the custom levels. Use the "...ALL" levels to log
        # on every rank.
        custom_level_names = {
            LOG_INFO_ALL: "INFOALL",
            LOG_DEBUG_ALL: "DEBUGALL",
            LOG_PING: "PING",
            LOG_STATUS: "STATUS",
            LOG_ERROR_SUMMARY: "ERROR_SUMMARY",
        }
        for level, level_name in custom_level_names.items():
            logging.addLevelName(level, level_name)

        # Names of the levels which are shown regardless of the rank.
        shown_on_all_ranks = (
            LOG_DEBUG_ALL,
            LOG_INFO_ALL,
            LOG_STATUS,
            logging.CRITICAL,
            logging.ERROR,
            logging.WARNING,
        )
        self.log_all_ranks = set(map(logging.getLevelName, shown_on_all_ranks))

        # Names of the levels which are never shown by this filter.
        self.log_no_ranks = {logging.getLevelName(LOG_ERROR_SUMMARY)}

    def filter(self, record: logging.LogRecord) -> bool:
        level_name = record.levelname

        # Suppressed levels are dropped before the rank is even considered.
        if level_name in self.log_no_ranks:
            return False

        if self.rank is None or self.rank == 0:
            return True

        return level_name in self.log_all_ranks
67
+
68
+
69
+ class ColoredFormatter(logging.Formatter):
70
+ """Defines a custom formatter for displaying logs."""
71
+
72
+ RESET_SEQ = "\033[0m"
73
+ COLOR_SEQ = "\033[1;%dm"
74
+ BOLD_SEQ = "\033[1m"
75
+
76
+ COLORS: dict[str, Color] = {
77
+ "WARNING": "yellow",
78
+ "INFOALL": "magenta",
79
+ "INFO": "cyan",
80
+ "DEBUGALL": "grey",
81
+ "DEBUG": "grey",
82
+ "CRITICAL": "yellow",
83
+ "FATAL": "red",
84
+ "ERROR": "red",
85
+ "STATUS": "green",
86
+ "PING": "magenta",
87
+ }
88
+
89
+ def __init__(
90
+ self,
91
+ *,
92
+ prefix: str | None = None,
93
+ rank: int | None = None,
94
+ world_size: int | None = None,
95
+ use_color: bool = True,
96
+ ) -> None:
97
+ asc_start, asc_end = color_parts("grey")
98
+ name_start, name_end = color_parts("blue", bold=True)
99
+
100
+ message_pre = [
101
+ "{levelname:^19s}",
102
+ asc_start,
103
+ "{asctime}",
104
+ asc_end,
105
+ " [",
106
+ name_start,
107
+ "{name}",
108
+ name_end,
109
+ "]",
110
+ ]
111
+ message_post = [" {message}"]
112
+
113
+ if prefix is not None:
114
+ message_pre += [" ", colored(prefix, "magenta", bold=True)]
115
+
116
+ if rank is not None or world_size is not None:
117
+ assert rank is not None and world_size is not None
118
+ digits = int(math.log10(world_size) + 1)
119
+ message_pre += [f" [{rank:0{digits}d}/{world_size}]"]
120
+ message = "".join(message_pre + message_post)
121
+
122
+ super().__init__(message, style="{", datefmt="%Y-%m-%d %H:%M:%S")
123
+
124
+ self.rank = rank
125
+ self.use_color = use_color
126
+
127
+ def format(self, record: logging.LogRecord) -> str:
128
+ levelname = record.levelname
129
+
130
+ match levelname:
131
+ case "DEBUG":
132
+ record.levelname = ""
133
+ case "INFOALL":
134
+ record.levelname = "INFO"
135
+ case "DEBUGALL":
136
+ record.levelname = "DEBUG"
137
+
138
+ if record.levelname and self.use_color and levelname in self.COLORS:
139
+ record.levelname = colored(record.levelname, self.COLORS[levelname], bold=True)
140
+ return logging.Formatter.format(self, record)
141
+
142
+
143
def configure_logging(prefix: str | None = None, *, rank: int | None = None, world_size: int | None = None) -> None:
    """Configures the root logger to write colored, rank-aware logs to stdout.

    This adds a stdout stream handler (using ``ColoredFormatter`` and
    ``RankFilter``) to the root logger, routes messages from the ``warnings``
    module through logging, and applies the log level from the user config.

    Each call adds another handler to the root logger, so this is meant to be
    called once per process.

    Args:
        prefix: An optional prefix to add to the logger
        rank: The current rank, or None if not using multiprocessing
        world_size: The total world size, or None if not using multiprocessing
    """
    if rank is not None or world_size is not None:
        assert rank is not None and world_size is not None
    root_logger = logging.getLogger()

    config = load_user_config().logging

    # Captures warnings from the warnings module.
    logging.captureWarnings(True)

    # Creating the filter also registers the custom level names, which must
    # happen before the configured level name is resolved below.
    rank_filter = RankFilter(rank=rank)

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(ColoredFormatter(prefix=prefix, rank=rank, world_size=world_size))
    stream_handler.addFilter(rank_filter)
    root_logger.addHandler(stream_handler)

    # `setLevel` accepts a level name directly, so there is no need to reach
    # into the private `logging._nameToLevel` mapping.
    root_logger.setLevel(config.log_level)

    # Avoid junk logs from other libraries.
    if config.hide_third_party_logs:
        for noisy_logger in ("matplotlib", "PIL", "torch"):
            logging.getLogger(noisy_logger).setLevel(logging.WARNING)
178
+
179
+
180
+ def get_unused_port(default: int | None = None) -> int:
181
+ """Returns an unused port number on the local machine.
182
+
183
+ Args:
184
+ default: A default port to try before trying other ports.
185
+
186
+ Returns:
187
+ A port number which is currently unused
188
+ """
189
+ if default is not None:
190
+ sock = socket.socket()
191
+ try:
192
+ sock.bind(("", default))
193
+ return default
194
+ except OSError:
195
+ pass
196
+ finally:
197
+ sock.close()
198
+
199
+ sock = socket.socket()
200
+ sock.bind(("", 0))
201
+ return sock.getsockname()[1]
202
+
203
+
204
+ OmegaConf.register_new_resolver("mlfab.unused_port", get_unused_port, replace=True)  # Registers `get_unused_port` as the OmegaConf resolver `mlfab.unused_port`, replacing any resolver already registered under that name. NOTE(review): the `mlfab` namespace does not match this package's name (`xax`) - confirm it is intentional.
205
+
206
+
207
def port_is_busy(port: int) -> bool:
    """Checks whether a port is busy.

    The check tries to bind a temporary socket to the port on all interfaces;
    the socket is always closed again before returning.

    Args:
        port: The port to check.

    Returns:
        Whether the port is busy.
    """
    with socket.socket() as sock:
        try:
            sock.bind(("", port))
        except OSError:
            # A failed bind is treated as the port being in use.
            return True
        return False
xax/utils/numpy.py ADDED
@@ -0,0 +1,47 @@
1
+ """Defines some Numpy utility functions."""
2
+
3
+ import numpy as np
4
+
5
+
6
def partial_flatten(x: np.ndarray) -> np.ndarray:
    """Collapses every dimension after the first into a single dimension.

    Args:
        x: The array to flatten.

    Returns:
        An array of shape ``(x.shape[0], -1)``.
    """
    leading = x.shape[0]
    target_shape = (leading, -1)
    return np.reshape(x, target_shape)
16
+
17
+
18
def one_hot(x: np.ndarray, k: int, dtype: type = np.float32) -> np.ndarray:
    """Builds the one-hot encoding of an array of labels.

    Args:
        x: The array of labels.
        k: The number of classes.
        dtype: The dtype of the returned array.

    Returns:
        The one-hot representation of the labels.
    """
    class_ids = np.arange(k)
    # Comparing each label (as a column) against every class id broadcasts to
    # a boolean mask, which is then converted to the requested dtype.
    label_column = x[:, None]
    mask = label_column == class_ids
    return np.array(mask, dtype)
30
+
31
+
32
def worker_chunk(x: np.ndarray, worker_id: int, num_workers: int, dim: int = 0) -> np.ndarray:
    """Returns the chunk of an array which belongs to a given worker.

    The array is split along ``dim`` into ``num_workers`` chunks of equal
    size, using floor division. Any trailing elements which do not fill a
    whole chunk are not assigned to any worker.

    Args:
        x: The array to chunk.
        worker_id: The worker ID.
        num_workers: The number of workers.
        dim: The dimension to chunk along.

    Returns:
        The chunked array.
    """
    chunk_size = x.shape[dim] // num_workers
    start = worker_id * chunk_size
    end = start + chunk_size

    # Slice along `dim` and keep every other dimension whole. Previously the
    # slice was always taken along the first dimension, even though the chunk
    # size was computed from `dim`.
    index = [slice(None)] * x.ndim
    index[dim] = slice(start, end)
    return x[tuple(index)]
xax/utils/tensorboard.py ADDED
@@ -0,0 +1,258 @@
1
+ """Defines utility functions for interfacing with Tensorboard."""
2
+
3
+ import functools
4
+ import io
5
+ import time
6
+ from pathlib import Path
7
+ from typing import Literal, TypedDict
8
+
9
+ from PIL.Image import Image as PILImage
10
+ from tensorboard.compat.proto.config_pb2 import RunMetadata
11
+ from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
12
+ from tensorboard.compat.proto.graph_pb2 import GraphDef
13
+ from tensorboard.compat.proto.summary_pb2 import Summary, SummaryMetadata
14
+ from tensorboard.compat.proto.tensor_pb2 import TensorProto
15
+ from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
16
+ from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
17
+ from tensorboard.summary.writer.event_file_writer import EventFileWriter
18
+
19
+ from xax.core.state import Phase
20
+
21
+ ImageShape = Literal["HWC", "CHW", "HW", "NHWC", "NCHW", "NHW"]  # Accepted image layout strings. NOTE(review): presumably N=batch, H=height, W=width, C=channels - confirm.
22
+
23
+
24
class TensorboardProtobufWriter:
    """Writes protobuf events to a Tensorboard event file.

    The underlying ``EventFileWriter`` is only created the first time the
    ``event_writer`` property is accessed.

    Parameters:
        log_directory: The directory to write logs to.
        max_queue_size: The maximum queue size.
        flush_seconds: How often to flush logs.
        filename_suffix: The filename suffix to use.
    """

    def __init__(
        self,
        log_directory: str | Path,
        max_queue_size: int = 10,
        flush_seconds: float = 120.0,
        filename_suffix: str = "",
    ) -> None:
        super().__init__()

        self.log_directory = Path(log_directory)
        self.max_queue_size = max_queue_size
        self.flush_seconds = flush_seconds
        self.filename_suffix = filename_suffix

    @functools.cached_property
    def event_writer(self) -> EventFileWriter:
        writer = EventFileWriter(
            logdir=str(self.log_directory),
            max_queue_size=self.max_queue_size,
            flush_secs=self.flush_seconds,
            filename_suffix=self.filename_suffix,
        )
        return writer

    def add_event(
        self,
        event: Event,
        step: int | None = None,
        walltime: float | None = None,
    ) -> None:
        # Stamp the event with the current time unless one was provided.
        if walltime is None:
            walltime = time.time()
        event.wall_time = walltime

        if step is not None:
            event.step = int(step)

        self.event_writer.add_event(event)

    def add_summary(
        self,
        summary: Summary,
        global_step: int | None = None,
        walltime: float | None = None,
    ) -> None:
        self.add_event(Event(summary=summary), step=global_step, walltime=walltime)

    def add_graph(
        self,
        graph: GraphDef,
        run_metadata: RunMetadata | None = None,
        walltime: float | None = None,
    ) -> None:
        graph_event = Event(graph_def=graph.SerializeToString())
        self.add_event(graph_event, walltime=walltime)

        if run_metadata is None:
            return

        # The run metadata, when given, is written as a second event.
        tagged_metadata = TaggedRunMetadata(tag="step1", run_metadata=run_metadata.SerializeToString())
        metadata_event = Event(tagged_run_metadata=tagged_metadata)
        self.add_event(metadata_event, walltime=walltime)

    def flush(self) -> None:
        self.event_writer.flush()

    def close(self) -> None:
        self.event_writer.close()
86
+
87
+
88
class TensorboardWriter:
    """Defines a class for writing artifacts to Tensorboard.

    Scalars, images and text are each wrapped in a ``Summary`` protobuf and
    passed to a ``TensorboardProtobufWriter``.

    Parameters:
        log_directory: The directory to write logs to.
        max_queue_size: The maximum queue size.
        flush_seconds: How often to flush logs.
        filename_suffix: The filename suffix to use.
    """

    def __init__(
        self,
        log_directory: str | Path,
        max_queue_size: int = 10,
        flush_seconds: float = 120.0,
        filename_suffix: str = "",
    ) -> None:
        super().__init__()

        self.pb_writer = TensorboardProtobufWriter(
            log_directory=log_directory,
            max_queue_size=max_queue_size,
            flush_seconds=flush_seconds,
            filename_suffix=filename_suffix,
        )

    def add_scalar(
        self,
        tag: str,
        value: float,
        global_step: int | None = None,
        walltime: float | None = None,
        new_style: bool = True,
        double_precision: bool = False,
    ) -> None:
        if new_style:
            # New-style scalars are stored as a tensor tagged for the
            # "scalars" plugin, in either single or double precision.
            if double_precision:
                tensor = TensorProto(double_val=[value], dtype="DT_DOUBLE")
            else:
                tensor = TensorProto(float_val=[value], dtype="DT_FLOAT")
            plugin_data = SummaryMetadata.PluginData(plugin_name="scalars")
            summary_value = Summary.Value(
                tag=tag,
                tensor=tensor,
                metadata=SummaryMetadata(plugin_data=plugin_data),
            )
        else:
            # Old-style scalars use the `simple_value` field.
            summary_value = Summary.Value(tag=tag, simple_value=value)

        self.pb_writer.add_summary(
            Summary(value=[summary_value]),
            global_step=global_step,
            walltime=walltime,
        )

    def add_image(
        self,
        tag: str,
        value: PILImage,
        global_step: int | None = None,
        walltime: float | None = None,
    ) -> None:
        # Encode the image as an RGB PNG held in memory.
        with io.BytesIO() as buffer:
            value.convert("RGB").save(buffer, format="PNG")
            png_bytes = buffer.getvalue()

        image = Summary.Image(
            height=value.height,
            width=value.width,
            colorspace=3,  # RGB
            encoded_image_string=png_bytes,
        )
        self.pb_writer.add_summary(
            Summary(value=[Summary.Value(tag=tag, image=image)]),
            global_step=global_step,
            walltime=walltime,
        )

    def add_text(
        self,
        tag: str,
        value: str,
        global_step: int | None = None,
        walltime: float | None = None,
    ) -> None:
        plugin_data = SummaryMetadata.PluginData(
            plugin_name="text",
            content=TextPluginData(version=0).SerializeToString(),
        )
        # The text is stored as a one-element string tensor.
        text_tensor = TensorProto(
            dtype="DT_STRING",
            string_val=[value.encode(encoding="utf_8")],
            tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
        )
        summary_value = Summary.Value(
            tag=tag + "/text_summary",
            metadata=SummaryMetadata(plugin_data=plugin_data),
            tensor=text_tensor,
        )
        self.pb_writer.add_summary(
            Summary(value=[summary_value]),
            global_step=global_step,
            walltime=walltime,
        )
217
+
218
+
219
class TensorboardWriterKwargs(TypedDict):
    """Typed keyword arguments used to construct a Tensorboard writer."""

    # The maximum queue size.
    max_queue_size: int
    # How often to flush logs, in seconds.
    flush_seconds: float
    # The filename suffix to use.
    filename_suffix: str
223
+
224
+
225
class TensorboardWriters:
    """Holds one lazily-created Tensorboard writer per phase.

    The train and valid writers log to the ``train`` and ``valid``
    subdirectories of the log directory, and each is only created on first
    access.

    Parameters:
        log_directory: The directory to write logs to.
        max_queue_size: The maximum queue size.
        flush_seconds: How often to flush logs.
        filename_suffix: The filename suffix to use.
    """

    def __init__(
        self,
        log_directory: str | Path,
        max_queue_size: int = 10,
        flush_seconds: float = 120.0,
        filename_suffix: str = "",
    ) -> None:
        super().__init__()

        self.log_directory = Path(log_directory)

        # Keyword arguments shared by every writer created by this object.
        self.kwargs: TensorboardWriterKwargs = TensorboardWriterKwargs(
            max_queue_size=max_queue_size,
            flush_seconds=flush_seconds,
            filename_suffix=filename_suffix,
        )

    @functools.cached_property
    def train_writer(self) -> TensorboardWriter:
        train_directory = self.log_directory / "train"
        return TensorboardWriter(train_directory, **self.kwargs)

    @functools.cached_property
    def valid_writer(self) -> TensorboardWriter:
        valid_directory = self.log_directory / "valid"
        return TensorboardWriter(valid_directory, **self.kwargs)

    def writer(self, phase: Phase) -> TensorboardWriter:
        """Returns the writer for the given phase.

        Args:
            phase: Either "train" or "valid".

        Returns:
            The writer for that phase.

        Raises:
            NotImplementedError: If the phase is not recognized.
        """
        if phase == "train":
            return self.train_writer
        if phase == "valid":
            return self.valid_writer
        raise NotImplementedError(f"Unexpected phase: {phase}")