xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +256 -1
- xax/core/conf.py +193 -0
- xax/core/state.py +81 -0
- xax/nn/__init__.py +0 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +77 -0
- xax/nn/parallel.py +211 -0
- xax/requirements-dev.txt +15 -0
- xax/requirements.txt +23 -0
- xax/task/__init__.py +0 -0
- xax/task/base.py +207 -0
- xax/task/launchers/__init__.py +0 -0
- xax/task/launchers/base.py +28 -0
- xax/task/launchers/cli.py +42 -0
- xax/task/launchers/single_process.py +30 -0
- xax/task/launchers/staged.py +29 -0
- xax/task/logger.py +783 -0
- xax/task/loggers/__init__.py +0 -0
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/json.py +121 -0
- xax/task/loggers/state.py +45 -0
- xax/task/loggers/stdout.py +170 -0
- xax/task/loggers/tensorboard.py +223 -0
- xax/task/mixins/__init__.py +12 -0
- xax/task/mixins/artifacts.py +114 -0
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +251 -0
- xax/task/mixins/data_loader.py +149 -0
- xax/task/mixins/gpu_stats.py +257 -0
- xax/task/mixins/logger.py +66 -0
- xax/task/mixins/process.py +51 -0
- xax/task/mixins/runnable.py +63 -0
- xax/task/mixins/step_wrapper.py +63 -0
- xax/task/mixins/train.py +541 -0
- xax/task/script.py +53 -0
- xax/task/task.py +65 -0
- xax/utils/__init__.py +0 -0
- xax/utils/data/__init__.py +0 -0
- xax/utils/data/collate.py +206 -0
- xax/utils/experiments.py +802 -0
- xax/utils/jax.py +14 -0
- xax/utils/logging.py +223 -0
- xax/utils/numpy.py +47 -0
- xax/utils/tensorboard.py +258 -0
- xax/utils/text.py +350 -0
- xax-0.0.5.dist-info/METADATA +40 -0
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.5.dist-info/top_level.txt +1 -0
- examples/mnist.py +0 -148
- xax-0.0.1.dist-info/METADATA +0 -21
- xax-0.0.1.dist-info/RECORD +0 -9
- xax-0.0.1.dist-info/top_level.txt +0 -2
- {examples → xax/core}/__init__.py +0 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/task/logger.py
ADDED
@@ -0,0 +1,783 @@
|
|
1
|
+
"""Defines the core logger.
|
2
|
+
|
3
|
+
A common problem when quickly prototyping ML models is nicely logging images,
|
4
|
+
videos, audio, or other data. Additionally, logging on every step can be
|
5
|
+
overwhelming. This logger implements a number of convenience functions to
|
6
|
+
take heterogeneous input data and put it into a standard format, which can
|
7
|
+
then be used by downstream loggers to actually log the data. For example, this
|
8
|
+
logger will automatically tile multiple images into a single image, add
|
9
|
+
captions to images, and so on.
|
10
|
+
"""
|
11
|
+
|
12
|
+
import functools
|
13
|
+
import logging
|
14
|
+
import math
|
15
|
+
import re
|
16
|
+
import time
|
17
|
+
from abc import ABC, abstractmethod
|
18
|
+
from collections import defaultdict
|
19
|
+
from dataclasses import dataclass
|
20
|
+
from types import TracebackType
|
21
|
+
from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, get_args
|
22
|
+
|
23
|
+
import jax
|
24
|
+
import jax.numpy as jnp
|
25
|
+
import numpy as np
|
26
|
+
from jaxtyping import Array
|
27
|
+
from omegaconf import DictConfig
|
28
|
+
from PIL import Image, ImageDraw, ImageFont
|
29
|
+
from PIL.Image import Image as PILImage
|
30
|
+
|
31
|
+
from xax.core.state import Phase, State
|
32
|
+
from xax.utils.experiments import IntervalTicker
|
33
|
+
from xax.utils.logging import LOG_ERROR_SUMMARY, LOG_PING, LOG_STATUS
|
34
|
+
|
35
|
+
logger = logging.getLogger(__name__)
|
36
|
+
|
37
|
+
T = TypeVar("T")
|
38
|
+
LogT = TypeVar("LogT")
|
39
|
+
Number = int | float | Array | np.ndarray
|
40
|
+
|
41
|
+
ChannelSelectMode = Literal["first", "last", "mean"]
|
42
|
+
|
43
|
+
DEFAULT_NAMESPACE = "value"
|
44
|
+
|
45
|
+
NAMESPACE_STACK: list[str] = []
|
46
|
+
|
47
|
+
|
48
|
+
def standardize_text(text: str, max_line_length: int | None = None, remove_non_ascii: bool = False) -> list[str]:
|
49
|
+
"""Standardizes a text string to a list of lines.
|
50
|
+
|
51
|
+
Args:
|
52
|
+
text: The text to standardize
|
53
|
+
max_line_length: If set, truncate lines to this length
|
54
|
+
remove_non_ascii: Remove non-ASCII characters if present
|
55
|
+
|
56
|
+
Returns:
|
57
|
+
The standardized text lines
|
58
|
+
"""
|
59
|
+
|
60
|
+
def _chunk_lines(text: str, max_length: int) -> Iterator[str]:
|
61
|
+
for i in range(0, len(text), max_length):
|
62
|
+
yield text[i : i + max_length]
|
63
|
+
|
64
|
+
if remove_non_ascii:
|
65
|
+
text = "".join(char for char in text if ord(char) < 128)
|
66
|
+
lines = [re.sub(r"\s+", " ", line) for line in re.split(r"[\n\r]+", text.strip())]
|
67
|
+
if max_line_length is not None:
|
68
|
+
lines = [subline for line in lines for subline in _chunk_lines(line, max_line_length)]
|
69
|
+
return lines
|
70
|
+
|
71
|
+
|
72
|
+
def make_human_viewable_resolution(
    image: PILImage,
    interpolation: Image.Resampling = Image.Resampling.LANCZOS,
    trg_res: tuple[int, int] = (512, 512),
) -> PILImage:
    """Resizes an image to a human-viewable resolution.

    The aspect ratio is preserved. The image is scaled so that its area is
    approximately the area of an image with resolution ``trg_res``.

    Args:
        image: The PIL image to resize
        interpolation: Interpolation mode to use for image resizing
        trg_res: The target image resolution, as (height, width); the image
            will be reshaped to have approximately the same area as an image
            with this resolution

    Returns:
        The resized image, or the input unchanged if it has zero area
    """
    width, height = image.size
    if width == 0 or height == 0:
        # Nothing to scale, and the area ratio below would divide by zero.
        return image
    trg_height, trg_width = trg_res
    factor = math.sqrt((trg_height * trg_width) / (height * width))
    # Clamp each side to at least one pixel. Otherwise very elongated images
    # would round a side down to zero, which PIL refuses to resize to.
    new_height = max(1, int(height * factor))
    new_width = max(1, int(width * factor))
    return image.resize((new_width, new_height), interpolation)
|
93
|
+
|
94
|
+
|
95
|
+
def image_with_text(
    image: PILImage,
    text: list[str],
    max_num_lines: int | None,
    line_spacing: int,
    centered: bool,
) -> PILImage:
    """Adds a text label below an image.

    The image is padded at the bottom with a white strip. The text is drawn
    into the strip in black, one entry of ``text`` per line.

    Args:
        image: The PIL image to label
        text: The text label for the image, one string per line
        max_num_lines: If set, the number of lines of spacing to add to the
            bottom of the image; extra lines of ``text`` are dropped. If None,
            one line of spacing is added per entry in ``text``
        line_spacing: The spacing between adjacent lines, in pixels
        centered: If set, center the text labels, otherwise align to the left

    Returns:
        The image with a text label, or the input unchanged if there is no
        text to draw
    """
    from PIL import ImageColor

    if max_num_lines is None:
        max_num_lines = len(text)
    else:
        text = text[:max_num_lines]
    # Checked after truncation, so that ``max_num_lines=0`` cannot reach the
    # ``text[0]`` lookup below with an empty list.
    if not text:
        return image
    width, height = image.size
    font: ImageFont.ImageFont = ImageFont.load_default()
    _, _, _, line_height = font.getbbox(text[0])
    new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)

    # Resolve white and black for the image's own mode. A bare integer such as
    # 255 is read as a packed colour in multi-band modes, so in "RGB" it means
    # red, not white.
    background = ImageColor.getcolor("white", image.mode)
    foreground = ImageColor.getcolor("black", image.mode)

    padded_image = Image.new(image.mode, (new_width, new_height), background)
    padded_image.paste(image, (0, 0))
    drawer = ImageDraw.Draw(padded_image)
    for i, text_line in enumerate(text):
        text_line_top = height + line_spacing + i * (line_height + line_spacing)
        if centered:
            _, _, line_width, _ = font.getbbox(text_line)
            text_line_left = (width - line_width) / 2
        else:
            text_line_left = line_spacing
        drawer.text((text_line_left, text_line_top), text_line, font=font, fill=foreground)
    return padded_image
|
137
|
+
|
138
|
+
|
139
|
+
class namespace_context:  # noqa: N801
    """Context manager that scopes logging namespaces through ``NAMESPACE_STACK``.

    Given a name, the name is pushed onto the module-level namespace stack on
    entry and popped on exit. Given ``None``, the stack is emptied on entry and
    its previous contents are restored on exit.
    """

    def __init__(self, name: str | None) -> None:
        self._name = name
        # Snapshot of the stack, taken on entry only when ``name`` is None.
        self._prev_stack: list[str] | None = None

    def __enter__(self) -> None:
        if self._name is not None:
            NAMESPACE_STACK.append(self._name)
            return
        saved = list(NAMESPACE_STACK)
        self._prev_stack = saved
        NAMESPACE_STACK.clear()

    def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
        if self._prev_stack is None:
            NAMESPACE_STACK.pop()
        else:
            NAMESPACE_STACK[:] = self._prev_stack
|
156
|
+
|
157
|
+
|
158
|
+
def normalize(x: np.ndarray) -> np.ndarray:
    """Min-max scales an array to the range [0, 1].

    Args:
        x: The array to normalize

    Returns:
        A floating-point array with the same shape as ``x``. If ``x`` is empty
        or constant (max equals min), an all-zero array is returned instead
        of dividing by zero.
    """
    out_dtype = np.result_type(x, 1.0)
    if x.size == 0:
        # ``min``/``max`` raise on empty arrays.
        return x.astype(out_dtype)
    lo, hi = x.min(), x.max()
    span = hi - lo
    if span == 0:
        # A constant array (e.g. a blank image) has no range. Dividing would
        # produce NaNs, which become garbage once cast to uint8.
        return np.zeros(x.shape, dtype=out_dtype)
    return (x - lo) / span
|
160
|
+
|
161
|
+
|
162
|
+
def ternary_search_optimal_side_counts(height: int, width: int, count: int) -> tuple[int, int]:
    """Chooses a grid layout for tiling ``count`` equally-sized images.

    Every factor pair ``(rows, cols)`` with ``rows * cols == count`` is
    considered. The returned pair is the one whose tiled canvas
    (``rows * height`` by ``cols * width``) has the longest shorter side,
    i.e. the most nearly square canvas. Ties go to the layout with fewer rows.

    Despite the name, the candidates are scanned exhaustively. A number has
    few divisors, so this is cheap, and it does not rely on the penalty being
    unimodal.

    Args:
        height: The height of a single image
        width: The width of a single image
        count: The number of images to tile

    Returns:
        The ``(rows, cols)`` pair to use

    Raises:
        ValueError: If ``count`` is not positive
    """
    if count < 1:
        raise ValueError(f"count must be positive, got {count}")

    # Divisors up to floor(sqrt(count)). Using ceil(sqrt(count)) would pick up
    # a divisor above the square root for non-square counts (e.g. 3 for 12),
    # which duplicates pairs and breaks the ordering of ``factors``.
    small = [i for i in range(1, math.isqrt(count) + 1) if count % i == 0]
    large = [count // i for i in reversed(small) if i * i != count]
    factors = [(i, count // i) for i in small] + [(i, count // i) for i in large]

    def penalty(pair: tuple[int, int]) -> int:
        rows, cols = pair
        return -(min(rows * height, cols * width) ** 2)

    # ``min`` returns the first minimum, i.e. the candidate with fewest rows.
    return min(factors, key=penalty)
|
192
|
+
|
193
|
+
|
194
|
+
def tile_images_different_sizes(images: list[PILImage], sep: int) -> PILImage:
    """Tiles a list of images into a single row, even if they have different sizes.

    Images are pasted left to right, aligned to the top edge, on an RGB canvas
    (black background) as tall as the tallest image.

    Args:
        images: The images to tile.
        sep: The separation between adjacent images, in pixels.

    Returns:
        The tiled image; an empty 0x0 RGB image if ``images`` is empty.
    """
    if not images:
        # ``max`` below would raise on an empty sequence; return the same
        # empty image that ``tile_images`` uses for this case.
        return Image.new("RGB", (0, 0))
    total_width = sum(image.width for image in images)
    max_height = max(image.height for image in images)
    tiled = Image.new("RGB", (total_width + (len(images) - 1) * sep, max_height))
    x = 0
    for image in images:
        tiled.paste(image, (x, 0))
        x += image.width + sep
    return tiled
|
211
|
+
|
212
|
+
|
213
|
+
def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
    """Tiles a list of images into a single grid image.

    Same-sized images are laid out row-major in a grid whose shape comes from
    ``ternary_search_optimal_side_counts``. If the sizes differ, this falls
    back to ``tile_images_different_sizes``.

    Args:
        images: The images to tile.
        sep: The separation between adjacent images.

    Returns:
        The tiled RGB image (a 0x0 image if ``images`` is empty).
    """
    if len(images) == 0:
        return Image.new("RGB", (0, 0))

    first = images[0]
    if any(image.size != first.size for image in images):
        return tile_images_different_sizes(images, sep)

    cell_h, cell_w = first.height, first.width
    rows, cols = ternary_search_optimal_side_counts(cell_h, cell_w, len(images))

    # Canvas: all cells plus one separator between each pair of neighbours.
    canvas_size = (cols * cell_w + (cols - 1) * sep, rows * cell_h + (rows - 1) * sep)
    tiled = Image.new("RGB", canvas_size)
    for index, image in enumerate(images):
        row, col = divmod(index, cols)
        tiled.paste(image, (col * (cell_w + sep), row * (cell_h + sep)))
    return tiled
|
240
|
+
|
241
|
+
|
242
|
+
def as_numpy(array: Array) -> np.ndarray:
    """Copies a JAX array to the host as a NumPy array.

    Dtype handling:

    - Floating-point arrays (including low-precision types such as bfloat16)
      are cast to float32.
    - uint8 arrays are kept as uint8.
    - Other integer arrays are cast to int32.
    - Anything else (e.g. bool) keeps its dtype.

    Args:
        array: The array to convert

    Returns:
        The array as a NumPy array
    """
    array = jax.device_get(array)
    if jax.dtypes.issubdtype(array.dtype, jnp.floating):
        array = array.astype(jnp.float32)
    elif array.dtype == np.uint8:
        # uint8 is the integer image dtype that ``get_image`` accepts. Casting
        # it to int32 would make ``get_image`` reject the image with
        # "Unsupported image dtype", so it is left untouched.
        pass
    elif jax.dtypes.issubdtype(array.dtype, jnp.integer):
        array = array.astype(jnp.int32)
    return np.array(array)
|
251
|
+
|
252
|
+
|
253
|
+
def get_image(image: np.ndarray | Array | PILImage, target_resolution: tuple[int, int] | None = None) -> PILImage:
    """Converts an array or PIL image into a PIL image for logging.

    Accepted array layouts are HW, HWC and CHW, with one or three channels.
    Array dtypes are handled as follows:

    - Floating-point arrays are min-max normalized to [0, 255].
    - Boolean arrays (e.g. masks) are mapped to 0 / 255.
    - uint8 arrays are used as-is.

    PIL images are passed through unchanged, apart from optional resizing.

    Args:
        image: The image, as a NumPy array, JAX array or PIL image
        target_resolution: If set, resize the image to have approximately the
            same area as an image of this (height, width) resolution

    Returns:
        The PIL image

    Raises:
        ValueError: If the image type, dtype or channel layout is unsupported
        RuntimeError: If an array image does not have two or three dimensions
    """
    if not isinstance(image, (np.ndarray, Array, PILImage)):
        raise ValueError(f"Unsupported image type: {type(image)}")
    if isinstance(image, Array):
        image = as_numpy(image)
    if isinstance(image, np.ndarray):
        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)
        if image.ndim != 3:
            raise RuntimeError(f"Expected image to have shape HW, HWC, or CHW, got {image.shape}")

        # Normalizes the image and converts to integer.
        if np.issubdtype(image.dtype, np.floating):
            image = (normalize(image) * 255).round().astype(np.uint8)
        elif image.dtype == np.bool_:
            # Boolean masks: False -> black, True -> white.
            image = np.where(image, 255, 0).astype(np.uint8)
        elif image.dtype != np.uint8:
            raise ValueError(f"Unsupported image dtype: {image.dtype}")

        # Converts to a PIL image. Channels-last layouts are checked before
        # channels-first ones.
        if image.shape[-1] == 1:
            image = Image.fromarray(image[..., 0])
        elif image.shape[-1] == 3:
            image = Image.fromarray(image)
        elif image.shape[0] == 1:
            image = Image.fromarray(image[0])
        elif image.shape[0] == 3:
            image = Image.fromarray(image.transpose(1, 2, 0))
        else:
            raise ValueError(f"Unsupported image shape: {image.shape}")

    if target_resolution is not None:
        image = make_human_viewable_resolution(image, trg_res=target_resolution)
    return image
|
287
|
+
|
288
|
+
|
289
|
+
@dataclass
class LogImage:
    """A PIL image wrapped for inclusion in a :class:`LogLine`."""

    # The image produced by one of the Logger image-logging helpers.
    image: PILImage
|
292
|
+
|
293
|
+
|
294
|
+
@dataclass
class LogLine:
    """All values logged for one step, together with that step's state.

    Each mapping is keyed first by namespace, then by key.
    """

    # The state passed to Logger.write for this step.
    state: State
    # namespace -> key -> evaluated scalar value.
    scalars: dict[str, dict[str, Number]]
    # namespace -> key -> evaluated string value.
    strings: dict[str, dict[str, str]]
    # namespace -> key -> rendered image.
    images: dict[str, dict[str, LogImage]]
|
300
|
+
|
301
|
+
|
302
|
+
@dataclass
class LogErrorSummary:
    """An error-summary message that is forwarded to each logger backend."""

    # The summary text.
    message: str
|
305
|
+
|
306
|
+
|
307
|
+
@dataclass
|
308
|
+
class LogError:
|
309
|
+
message: str
|
310
|
+
location: str | None = None
|
311
|
+
|
312
|
+
@property
|
313
|
+
def message_with_location(self) -> str:
|
314
|
+
message = self.message
|
315
|
+
if self.location is not None:
|
316
|
+
message += f" ({self.location})"
|
317
|
+
return message
|
318
|
+
|
319
|
+
|
320
|
+
@dataclass
|
321
|
+
class LogStatus:
|
322
|
+
message: str
|
323
|
+
created: float
|
324
|
+
filename: str | None = None
|
325
|
+
lineno: int | None = None
|
326
|
+
|
327
|
+
|
328
|
+
@dataclass
|
329
|
+
class LogPing:
|
330
|
+
message: str
|
331
|
+
created: float
|
332
|
+
filename: str | None = None
|
333
|
+
lineno: int | None = None
|
334
|
+
|
335
|
+
|
336
|
+
class LoggerImpl(ABC):
    """Base class for logging backends that receive packed log lines.

    Subclasses must implement :meth:`write`. Every other hook is a no-op by
    default, so a backend only overrides the hooks it supports.
    """

    def __init__(self, log_interval_seconds: float = 1.0) -> None:
        """Defines some default behavior for loggers.

        Every logger needs to implement the ``write`` function, which handles
        actually writing the logs to wherever they need to go. The base class
        implements a simple interval-based logging scheme to avoid writing
        too many log lines.

        Args:
            log_interval_seconds: The interval between successive log lines.
        """
        super().__init__()

        # One ticker per phase, each built with the same interval.
        tickers = {}
        for phase in get_args(Phase):
            tickers[phase] = IntervalTicker(log_interval_seconds)
        self.tickers = tickers

    def start(self) -> None:
        """Hook run when the owning Logger is entered; does nothing by default."""

    def stop(self) -> None:
        """Hook run when the owning Logger is exited; does nothing by default."""

    @abstractmethod
    def write(self, line: LogLine) -> None:
        """Handles writing the current log line.

        Args:
            line: The line to write.
        """

    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
        """Handles writing an error summary; ignored by default.

        Args:
            error_summary: The error summary to write.
        """

    def write_error(self, error: LogError) -> None:
        """Handles writing an error line; ignored by default.

        Args:
            error: The error information to write.
        """

    def write_status(self, status: LogStatus) -> None:
        """Handles writing a status line; ignored by default.

        Args:
            status: The status to write.
        """

    def write_ping(self, ping: LogPing) -> None:
        """Handles writing a ping line; ignored by default.

        Args:
            ping: The ping to write.
        """

    def log_git_state(self, git_state: str) -> None:
        """Logs Git state for the current run; ignored by default.

        Args:
            git_state: The Git state, as text blocks.
        """

    def log_training_code(self, training_code: str) -> None:
        """Logs the training script code; ignored by default.

        Args:
            training_code: The training script code.
        """

    def log_config(self, config: DictConfig) -> None:
        """Logs the configuration for the current run; ignored by default.

        Args:
            config: The configuration, as a DictConfig.
        """

    def should_log(self, state: State) -> bool:
        """Decides whether the current step should be logged.

        The decision is delegated to the interval ticker for ``state.phase``.

        Args:
            state: The current step's state.

        Returns:
            If the logger should log the current step.
        """
        ticker = self.tickers[state.phase]
        return ticker.tick(state.elapsed_time_s)
|
425
|
+
|
426
|
+
|
427
|
+
class ToastHandler(logging.Handler):
    """Logging handler that forwards selected records to a ``Logger``.

    Records are routed by level:

    - ``LOG_ERROR_SUMMARY`` goes to ``Logger.write_error_summary``.
    - ``LOG_STATUS`` goes to ``Logger.write_status``.
    - ``LOG_PING`` and ``WARNING`` go to ``Logger.write_ping``.
    - ``ERROR`` and ``CRITICAL`` go to ``Logger.write_error``.

    Records at any other level are ignored.
    """

    def __init__(self, logger: "Logger") -> None:
        super().__init__()

        self.logger = logger

    def emit(self, record: logging.LogRecord) -> None:
        try:
            if record.levelno == LOG_ERROR_SUMMARY:
                self.logger.write_error_summary(record.getMessage())
            elif record.levelno == LOG_STATUS:
                self.logger.write_status(record.getMessage(), record.filename, record.lineno)
            elif record.levelno in (LOG_PING, logging.WARNING):
                self.logger.write_ping(record.getMessage(), record.filename, record.lineno)
            elif record.levelno in (logging.ERROR, logging.CRITICAL):
                # WARNING is intentionally absent: the ping branch above
                # already handles it, so listing it here would be dead code.
                self.logger.write_error(record.getMessage(), f"{record.filename}:{record.lineno}")
        except RecursionError:
            # As in the standard-library handlers, re-raise RecursionError
            # rather than passing it to handleError.
            raise
        except Exception:
            self.handleError(record)

    def add_for_logger(self, logger: logging.Logger) -> None:
        """Installs this handler on ``logger``, replacing any existing ToastHandler.

        Args:
            logger: The standard-library logger to attach to.
        """
        # Removes existing ToastHandlers. The matches are collected first,
        # because removeHandler mutates ``logger.handlers``.
        stale = [handler for handler in logger.handlers if isinstance(handler, ToastHandler)]
        for handler in stale:
            logger.removeHandler(handler)

        # Adds the new ToastHandler.
        logger.addHandler(self)
|
459
|
+
|
460
|
+
|
461
|
+
class Logger:
    """Defines an intermediate container which holds values to log somewhere else.

    Each logged value is stored as a memoized zero-argument callable, keyed
    first by namespace and then by key. On :meth:`write` the callables are
    evaluated and packed into a :class:`LogLine`. The line is then passed to
    every registered :class:`LoggerImpl` whose ``should_log`` accepted the
    step. The ``log_*`` value methods may only be called while the logger is
    active, i.e. inside its ``with`` block.
    """

    def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
        # Pending values: namespace -> key -> memoized thunk producing the value.
        self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
        self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
        self.images: dict[str, dict[str, Callable[[], PILImage]]] = defaultdict(dict)
        self.default_namespace = default_namespace
        self.loggers: list[LoggerImpl] = []

        # Registers a handler on the root logger, so that records emitted via
        # the standard ``logging`` module are routed to this logger. Any
        # ToastHandler installed by a previous Logger is replaced.
        root_logger = logging.getLogger()
        ToastHandler(self).add_for_logger(root_logger)

        # Set to True by __enter__ and back to False by __exit__.
        self.active = False

    def _ensure_active(self) -> None:
        """Raises ``RuntimeError`` unless the logger is inside its ``with`` block."""
        if not self.active:
            raise RuntimeError("The logger is not active")

    def add_logger(self, *logger: LoggerImpl) -> None:
        """Add the logger, so that it gets called when `write` is called.

        Args:
            logger: The logger(s) to add.
        """
        self.loggers.extend(logger)

    def pack(self, state: State) -> LogLine:
        """Evaluates every pending value and bundles the results into a LogLine.

        Args:
            state: The current step's state.

        Returns:
            The packed log line.
        """
        return LogLine(
            state=state,
            scalars={ns: {key: fn() for key, fn in values.items()} for ns, values in self.scalars.items()},
            strings={ns: {key: fn() for key, fn in values.items()} for ns, values in self.strings.items()},
            images={ns: {key: LogImage(fn()) for key, fn in values.items()} for ns, values in self.images.items()},
        )

    def clear(self) -> None:
        """Discards all pending scalars, strings and images."""
        self.scalars.clear()
        self.strings.clear()
        self.images.clear()

    def write(self, state: State) -> None:
        """Writes the current step's logging information.

        Each backend's ``should_log`` is called exactly once. If no backend
        wants this step, the pending values are discarded without being
        evaluated. Pending values are cleared in every case, including when
        evaluating one of them raises.

        Args:
            state: The current step's state.
        """
        should_log = [logger.should_log(state) for logger in self.loggers]
        if not any(should_log):
            self.clear()
            return
        try:
            line = self.pack(state)
        finally:
            # Otherwise a value whose thunk raises would stay pending and
            # re-raise on every later step.
            self.clear()
        for logger, wants_line in zip(self.loggers, should_log):
            if wants_line:
                logger.write(line)

    def write_error_summary(self, error_summary: str) -> None:
        """Forwards an error summary to every backend.

        Args:
            error_summary: The summary text.
        """
        for logger in self.loggers:
            logger.write_error_summary(LogErrorSummary(error_summary))

    def write_error(self, message: str, location: str | None = None) -> None:
        """Forwards an error message to every backend.

        Args:
            message: The error message.
            location: Optional location text for the error.
        """
        for logger in self.loggers:
            logger.write_error(LogError(message, location))

    def write_status(
        self,
        message: str,
        filename: str | None = None,
        lineno: int | None = None,
        created: float | None = None,
    ) -> None:
        """Forwards a status message to every backend.

        Args:
            message: The status message.
            filename: The source file that emitted the status, if known.
            lineno: The source line that emitted the status, if known.
            created: Creation timestamp; defaults to ``time.time()``.
        """
        status = LogStatus(message, time.time() if created is None else created, filename, lineno)
        for logger in self.loggers:
            logger.write_status(status)

    def write_ping(
        self,
        message: str,
        filename: str | None = None,
        lineno: int | None = None,
        created: float | None = None,
    ) -> None:
        """Forwards a ping message to every backend.

        Args:
            message: The ping message.
            filename: The source file that emitted the ping, if known.
            lineno: The source line that emitted the ping, if known.
            created: Creation timestamp; defaults to ``time.time()``.
        """
        ping = LogPing(message, time.time() if created is None else created, filename, lineno)
        for logger in self.loggers:
            logger.write_ping(ping)

    def resolve_namespace(self, namespace: str | None = None) -> str:
        """Combines the base namespace with the active ``namespace_context`` stack.

        Args:
            namespace: The base namespace; defaults to ``default_namespace``.

        Returns:
            The base namespace followed by each stacked name, joined by ``_``.
        """
        return "_".join([self.default_namespace if namespace is None else namespace] + NAMESPACE_STACK)

    def log_scalar(self, key: str, value: Callable[[], Number] | Number, *, namespace: str | None = None) -> None:
        """Logs a scalar value.

        Args:
            key: The key being logged
            value: The scalar value being logged, or a callable producing it
                (evaluated lazily, at most once, when the line is packed)
            namespace: An optional logging namespace

        Raises:
            RuntimeError: If the logger is not active
        """
        self._ensure_active()
        namespace = self.resolve_namespace(namespace)

        @functools.lru_cache(maxsize=None)
        def scalar_future() -> Number:
            return value() if callable(value) else value

        self.scalars[namespace][key] = scalar_future

    def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
        """Logs a string value.

        Args:
            key: The key being logged
            value: The string value being logged, or a callable producing it
                (evaluated lazily, at most once, when the line is packed)
            namespace: An optional logging namespace

        Raises:
            RuntimeError: If the logger is not active
        """
        self._ensure_active()
        namespace = self.resolve_namespace(namespace)

        @functools.lru_cache(maxsize=None)
        def value_future() -> str:
            return value() if callable(value) else value

        self.strings[namespace][key] = value_future

    def log_image(
        self,
        key: str,
        value: Callable[[], np.ndarray | Array | PILImage] | np.ndarray | Array | PILImage,
        *,
        namespace: str | None = None,
        target_resolution: tuple[int, int] | None = (512, 512),
    ) -> None:
        """Logs an image.

        Args:
            key: The key being logged
            value: The image being logged, or a callable producing it
            namespace: An optional logging namespace
            target_resolution: The target resolution for each image; if None,
                don't resample the images

        Raises:
            RuntimeError: If the logger is not active
        """
        self._ensure_active()
        namespace = self.resolve_namespace(namespace)

        @functools.lru_cache(maxsize=None)
        def image_future() -> PILImage:
            return get_image(value() if callable(value) else value, target_resolution)

        self.images[namespace][key] = image_future

    def log_labeled_image(
        self,
        key: str,
        value: Callable[[], tuple[np.ndarray | Array | PILImage, str]] | tuple[np.ndarray | Array | PILImage, str],
        *,
        namespace: str | None = None,
        max_line_length: int | None = None,
        max_num_lines: int | None = None,
        target_resolution: tuple[int, int] | None = (512, 512),
        line_spacing: int = 2,
        centered: bool = True,
    ) -> None:
        """Logs an image with a label.

        Args:
            key: The key being logged
            value: The image and label being logged, or a callable producing them
            namespace: An optional logging namespace
            max_line_length: The maximum line length for the label
            max_num_lines: The number of lines of spacing to add to the bottom
                of the image
            target_resolution: The target resolution for each image; if None,
                don't resample the images
            line_spacing: The spacing between adjacent lines
            centered: If set, center the text labels, otherwise align to the left

        Raises:
            RuntimeError: If the logger is not active
        """
        self._ensure_active()
        namespace = self.resolve_namespace(namespace)

        @functools.lru_cache(maxsize=None)
        def image_future() -> PILImage:
            image, label = value() if callable(value) else value
            image = get_image(image, target_resolution)
            return image_with_text(
                image,
                standardize_text(label, max_line_length),
                max_num_lines=max_num_lines,
                line_spacing=line_spacing,
                centered=centered,
            )

        self.images[namespace][key] = image_future

    def log_images(
        self,
        key: str,
        value: (
            Callable[[], Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array]
            | Sequence[np.ndarray | Array | PILImage]
            | np.ndarray
            | Array
        ),
        *,
        namespace: str | None = None,
        max_images: int | None = None,
        target_resolution: tuple[int, int] | None = (256, 256),
        sep: int = 0,
    ) -> None:
        """Logs a set of images.

        The images are tiled into a single grid image.

        Args:
            key: The key being logged
            value: The images being logged (a sequence, or an array whose
                first axis indexes the images), or a callable producing them
            namespace: An optional logging namespace
            max_images: The maximum number of images to show; extra images
                are clipped
            target_resolution: The target resolution for each image; if None,
                don't resample the images
            sep: An optional separation amount between adjacent images

        Raises:
            RuntimeError: If the logger is not active
        """
        self._ensure_active()
        namespace = self.resolve_namespace(namespace)

        @functools.lru_cache(maxsize=None)
        def images_future() -> PILImage:
            images = value() if callable(value) else value
            if max_images is not None:
                images = images[:max_images]
            if isinstance(images, Array):
                images = as_numpy(images)
            if isinstance(images, Sequence):
                images = list(images)
            images = [get_image(image, target_resolution) for image in images]
            return tile_images(images, sep)

        self.images[namespace][key] = images_future

    def log_labeled_images(
        self,
        key: str,
        value: (
            Callable[[], tuple[Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array, Sequence[str]]]
            | tuple[Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array, Sequence[str]]
        ),
        *,
        namespace: str | None = None,
        max_images: int | None = None,
        max_line_length: int | None = None,
        max_num_lines: int | None = None,
        target_resolution: tuple[int, int] | None = (256, 256),
        line_spacing: int = 2,
        centered: bool = True,
        sep: int = 0,
    ) -> None:
        """Logs a set of images with labels.

        The labelled images are tiled into a single grid image. Images and
        labels are paired with ``zip``, so any surplus on either side is
        dropped.

        Args:
            key: The key being logged
            value: The images and labels being logged, or a callable producing them
            namespace: An optional logging namespace
            max_images: The maximum number of images to show; extra images
                are clipped
            max_line_length: The maximum line length for the label
            max_num_lines: The number of lines of spacing to add to the bottom
                of the image
            target_resolution: The target resolution for each image; if None,
                don't resample the images
            line_spacing: The spacing between adjacent lines
            centered: If set, center the text labels, otherwise align to the left
            sep: An optional separation amount between adjacent images

        Raises:
            RuntimeError: If the logger is not active
        """
        self._ensure_active()
        namespace = self.resolve_namespace(namespace)

        @functools.lru_cache(maxsize=None)
        def images_future() -> PILImage:
            images, labels = value() if callable(value) else value
            if max_images is not None:
                images = images[:max_images]
                labels = labels[:max_images]
            images = [get_image(image, target_resolution) for image in images]
            images = [
                image_with_text(
                    image,
                    standardize_text(label, max_line_length),
                    max_num_lines=max_num_lines,
                    line_spacing=line_spacing,
                    centered=centered,
                )
                for image, label in zip(images, labels)
            ]
            return tile_images(images, sep)

        self.images[namespace][key] = images_future

    def log_git_state(self, git_state: str) -> None:
        """Forwards the Git state to every backend."""
        for logger in self.loggers:
            logger.log_git_state(git_state)

    def log_training_code(self, training_code: str) -> None:
        """Forwards the training script code to every backend."""
        for logger in self.loggers:
            logger.log_training_code(training_code)

    def log_config(self, config: DictConfig) -> None:
        """Forwards the run configuration to every backend."""
        for logger in self.loggers:
            logger.log_config(config)

    def __enter__(self) -> Self:
        """Activates the logger and starts every backend."""
        self.active = True
        for logger in self.loggers:
            logger.start()
        return self

    def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
        """Stops every backend and deactivates the logger.

        Every backend's ``stop`` is attempted, even if an earlier one raises.
        The first such exception is re-raised after the logger is marked
        inactive.
        """
        first_error: Exception | None = None
        for logger in self.loggers:
            try:
                logger.stop()
            except Exception as e:
                # Keep stopping the remaining backends so they can release
                # their resources.
                if first_error is None:
                    first_error = e
        self.active = False
        if first_error is not None:
            raise first_error
|