xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. xax/__init__.py +256 -1
  2. xax/core/conf.py +193 -0
  3. xax/core/state.py +81 -0
  4. xax/nn/__init__.py +0 -0
  5. xax/nn/embeddings.py +355 -0
  6. xax/nn/functions.py +77 -0
  7. xax/nn/parallel.py +211 -0
  8. xax/requirements-dev.txt +15 -0
  9. xax/requirements.txt +23 -0
  10. xax/task/__init__.py +0 -0
  11. xax/task/base.py +207 -0
  12. xax/task/launchers/__init__.py +0 -0
  13. xax/task/launchers/base.py +28 -0
  14. xax/task/launchers/cli.py +42 -0
  15. xax/task/launchers/single_process.py +30 -0
  16. xax/task/launchers/staged.py +29 -0
  17. xax/task/logger.py +783 -0
  18. xax/task/loggers/__init__.py +0 -0
  19. xax/task/loggers/callback.py +56 -0
  20. xax/task/loggers/json.py +121 -0
  21. xax/task/loggers/state.py +45 -0
  22. xax/task/loggers/stdout.py +170 -0
  23. xax/task/loggers/tensorboard.py +223 -0
  24. xax/task/mixins/__init__.py +12 -0
  25. xax/task/mixins/artifacts.py +114 -0
  26. xax/task/mixins/checkpointing.py +209 -0
  27. xax/task/mixins/cpu_stats.py +251 -0
  28. xax/task/mixins/data_loader.py +149 -0
  29. xax/task/mixins/gpu_stats.py +257 -0
  30. xax/task/mixins/logger.py +66 -0
  31. xax/task/mixins/process.py +51 -0
  32. xax/task/mixins/runnable.py +63 -0
  33. xax/task/mixins/step_wrapper.py +63 -0
  34. xax/task/mixins/train.py +541 -0
  35. xax/task/script.py +53 -0
  36. xax/task/task.py +65 -0
  37. xax/utils/__init__.py +0 -0
  38. xax/utils/data/__init__.py +0 -0
  39. xax/utils/data/collate.py +206 -0
  40. xax/utils/experiments.py +802 -0
  41. xax/utils/jax.py +14 -0
  42. xax/utils/logging.py +223 -0
  43. xax/utils/numpy.py +47 -0
  44. xax/utils/tensorboard.py +258 -0
  45. xax/utils/text.py +350 -0
  46. xax-0.0.5.dist-info/METADATA +40 -0
  47. xax-0.0.5.dist-info/RECORD +52 -0
  48. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
  49. xax-0.0.5.dist-info/top_level.txt +1 -0
  50. examples/mnist.py +0 -148
  51. xax-0.0.1.dist-info/METADATA +0 -21
  52. xax-0.0.1.dist-info/RECORD +0 -9
  53. xax-0.0.1.dist-info/top_level.txt +0 -2
  54. {examples → xax/core}/__init__.py +0 -0
  55. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/task/logger.py ADDED
@@ -0,0 +1,783 @@
1
+ """Defines the core logger.
2
+
3
+ A common problem when quickly prototyping ML models is nicely logging images,
4
+ videos, audio, or other data. Additionally, logging on every step can be
5
+ overwhelming. This logger implements a number of convenience functions to
6
+ take heterogeneous input data and put it into a standard format, which can
7
+ then be used by downstream loggers to actually log the data. For example, this
8
+ logger will automatically tile multiple images into a single image, add
9
+ captions to images, and so on.
10
+ """
11
+
12
+ import functools
13
+ import logging
14
+ import math
15
+ import re
16
+ import time
17
+ from abc import ABC, abstractmethod
18
+ from collections import defaultdict
19
+ from dataclasses import dataclass
20
+ from types import TracebackType
21
+ from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, get_args
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+ from jaxtyping import Array
27
+ from omegaconf import DictConfig
28
+ from PIL import Image, ImageDraw, ImageFont
29
+ from PIL.Image import Image as PILImage
30
+
31
+ from xax.core.state import Phase, State
32
+ from xax.utils.experiments import IntervalTicker
33
+ from xax.utils.logging import LOG_ERROR_SUMMARY, LOG_PING, LOG_STATUS
34
+
35
# Module-level logger for messages emitted by this file.
logger = logging.getLogger(__name__)

# Generic type variables.
T = TypeVar("T")
LogT = TypeVar("LogT")

# Anything accepted as a scalar value by ``Logger.log_scalar``.
Number = int | float | Array | np.ndarray

ChannelSelectMode = Literal["first", "last", "mean"]

# Namespace used by ``Logger`` when the caller does not provide one.
DEFAULT_NAMESPACE = "value"

# Process-wide stack of nested namespace names; pushed and popped by
# ``namespace_context`` and appended to the namespace by
# ``Logger.resolve_namespace``.
NAMESPACE_STACK: list[str] = []
46
+
47
+
48
+ def standardize_text(text: str, max_line_length: int | None = None, remove_non_ascii: bool = False) -> list[str]:
49
+ """Standardizes a text string to a list of lines.
50
+
51
+ Args:
52
+ text: The text to standardize
53
+ max_line_length: If set, truncate lines to this length
54
+ remove_non_ascii: Remove non-ASCII characters if present
55
+
56
+ Returns:
57
+ The standardized text lines
58
+ """
59
+
60
+ def _chunk_lines(text: str, max_length: int) -> Iterator[str]:
61
+ for i in range(0, len(text), max_length):
62
+ yield text[i : i + max_length]
63
+
64
+ if remove_non_ascii:
65
+ text = "".join(char for char in text if ord(char) < 128)
66
+ lines = [re.sub(r"\s+", " ", line) for line in re.split(r"[\n\r]+", text.strip())]
67
+ if max_line_length is not None:
68
+ lines = [subline for line in lines for subline in _chunk_lines(line, max_line_length)]
69
+ return lines
70
+
71
+
72
def make_human_viewable_resolution(
    image: PILImage,
    interpolation: Image.Resampling = Image.Resampling.LANCZOS,
    trg_res: tuple[int, int] = (512, 512),
) -> PILImage:
    """Resizes image to human-viewable resolution.

    The aspect ratio is preserved; the image is scaled so that its area is
    approximately the area of ``trg_res``.

    Args:
        image: The PIL image to resize
        interpolation: Interpolation mode to use for image resizing
        trg_res: The target image resolution, as (height, width); the image
            will be reshaped to have approximately the same area as an image
            with this resolution

    Returns:
        The resized image. An image with zero width or height is returned
        unchanged, since there is nothing to scale.
    """
    width, height = image.size
    if width == 0 or height == 0:
        # Avoids dividing by a zero area below.
        return image

    trg_height, trg_width = trg_res
    factor = math.sqrt((trg_height * trg_width) / (height * width))

    # Truncation can round a very thin side down to zero, which ``resize``
    # rejects, so every side is kept at least one pixel long.
    new_height = max(1, int(height * factor))
    new_width = max(1, int(width * factor))
    return image.resize((new_width, new_height), interpolation)
93
+
94
+
95
def image_with_text(
    image: PILImage,
    text: list[str],
    max_num_lines: int | None,
    line_spacing: int,
    centered: bool,
) -> PILImage:
    """Adds a text label to an image.

    The label is drawn in black on a white strip appended below the image.

    Args:
        image: The PIL image to label
        text: The text label for the image, one entry per line
        max_num_lines: The number of lines of spacing to add to the bottom
            of the image; extra label lines are dropped. If None, room is
            made for every line in ``text``
        line_spacing: The spacing between adjacent lines
        centered: If set, center the text labels, otherwise align to the left

    Returns:
        The image with a text label, or the input image unchanged when there
        is no line to draw
    """
    if not text:
        return image
    if max_num_lines is None:
        max_num_lines = len(text)
    else:
        if max_num_lines <= 0:
            # No room was requested for the label, so there is nothing to add.
            return image
        text = text[:max_num_lines]

    width, height = image.size
    font: ImageFont.ImageFont = ImageFont.load_default()
    _, _, _, line_height = font.getbbox(text[0])
    new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)

    # Colour names are resolved per image mode. A bare integer is treated as a
    # packed value on multi-band images, so 255 would not be white on RGB/RGBA.
    padded_image = Image.new(image.mode, (new_width, new_height), "white")
    padded_image.paste(image, (0, 0))
    drawer = ImageDraw.Draw(padded_image)
    for i, text_line in enumerate(text):
        text_line_top = height + line_spacing + i * (line_height + line_spacing)
        if centered:
            _, _, line_width, _ = font.getbbox(text_line)
            text_line_left = (width - line_width) / 2
        else:
            text_line_left = line_spacing
        drawer.text((text_line_left, text_line_top), text_line, font=font, fill="black")
    return padded_image
137
+
138
+
139
class namespace_context:  # noqa: N801
    """Context manager that scopes the module-level ``NAMESPACE_STACK``.

    With a name, the name is pushed on entry and popped on exit. With
    ``None``, the stack is emptied on entry and its previous contents are
    restored on exit.
    """

    def __init__(self, name: str | None) -> None:
        self._name = name
        # Snapshot of the stack, only taken when entering with a ``None`` name.
        self._prev_stack: list[str] | None = None

    def __enter__(self) -> None:
        if self._name is not None:
            NAMESPACE_STACK.append(self._name)
            return

        self._prev_stack = list(NAMESPACE_STACK)
        del NAMESPACE_STACK[:]

    def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
        snapshot = self._prev_stack
        if snapshot is None:
            NAMESPACE_STACK.pop()
            return

        # Mutates in place so that every holder of the list sees the restore.
        NAMESPACE_STACK[:] = snapshot
156
+
157
+
158
def normalize(x: np.ndarray) -> np.ndarray:
    """Min-max scales an array into the range [0, 1].

    Args:
        x: The array to rescale

    Returns:
        ``(x - min) / (max - min)``; an array whose elements are all equal
        maps to all zeros instead of NaN
    """
    lo = x.min()
    span = x.max() - lo
    shifted = x - lo
    if span == 0:
        # Constant input: ``shifted`` is already all zeros, so divide by one
        # rather than by zero. The result dtype matches the regular path.
        return shifted / 1
    return shifted / span
160
+
161
+
162
def ternary_search_optimal_side_counts(height: int, width: int, count: int) -> tuple[int, int]:
    """Finds the grid layout that makes a tiling of equal-sized cells most square.

    Every exact factorization ``rows * cols == count`` is a candidate; the
    one maximizing the shorter side of the tiled canvas is returned.

    Args:
        height: The height of a single cell
        width: The width of a single cell
        count: The number of cells to lay out

    Returns:
        The (rows, cols) layout, such that ``rows * cols == count``

    Raises:
        ValueError: If ``count`` is less than one
    """
    if count < 1:
        raise ValueError(f"Expected a positive cell count, got {count}")

    # Enumerates divisors up to floor(sqrt(count)). Going up to the ceiling
    # would list some factor pairs twice (e.g. for 6 or 12), which breaks the
    # ordering the search below relies on.
    min_factors = [i for i in range(1, math.isqrt(count) + 1) if count % i == 0]
    max_factors = [i for i in min_factors[::-1] if i * i != count]

    # Candidates ordered by increasing row count (and decreasing column count).
    factors = [(i, count // i) for i in min_factors] + [(count // i, i) for i in max_factors]

    lo, hi = 0, len(factors) - 1

    def penalty(i: int) -> float:
        hval, wval = factors[i]
        h, w = hval * height, wval * width
        return -(min(h, w) ** 2)

    # Runs ternary search to minimize penalty.
    while lo < hi - 2:
        lmid, rmid = (lo * 2 + hi) // 3, (lo + hi * 2) // 3
        if penalty(lmid) > penalty(rmid):
            lo = lmid
        else:
            hi = rmid

    # Returns the lowest-penalty configuration among the remaining candidates.
    mid = (lo + hi) // 2
    plo, pmid, phi = penalty(lo), penalty(mid), penalty(hi)

    if pmid <= plo and pmid <= phi:
        return factors[mid]
    elif plo <= phi:
        return factors[lo]
    else:
        return factors[hi]
192
+
193
+
194
def tile_images_different_sizes(images: list[PILImage], sep: int) -> PILImage:
    """Lays out images of arbitrary sizes side by side in a single row.

    Images are top-aligned on an RGB canvas as tall as the tallest image.

    Args:
        images: The images to tile.
        sep: The separation between adjacent images.

    Returns:
        The tiled image.
    """
    widths = [picture.width for picture in images]
    canvas_height = max(picture.height for picture in images)
    canvas_width = sum(widths) + sep * (len(images) - 1)

    canvas = Image.new("RGB", (canvas_width, canvas_height))
    offset = 0
    for picture, picture_width in zip(images, widths):
        canvas.paste(picture, (offset, 0))
        offset += picture_width + sep
    return canvas
211
+
212
+
213
def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
    """Combines a list of images into one RGB image.

    Equal-sized images are arranged on a grid, filled row by row; images of
    differing sizes are handed to ``tile_images_different_sizes`` instead.

    Args:
        images: The images to tile.
        sep: The separation between adjacent images.

    Returns:
        The tiled image.
    """
    if not images:
        return Image.new("RGB", (0, 0))

    reference = images[0]
    if any(candidate.size != reference.size for candidate in images):
        return tile_images_different_sizes(images, sep)

    # Picks the grid shape for the shared cell size.
    cell_width, cell_height = reference.size
    rows, cols = ternary_search_optimal_side_counts(cell_height, cell_width, len(images))

    canvas_size = (cols * cell_width + (cols - 1) * sep, rows * cell_height + (rows - 1) * sep)
    canvas = Image.new("RGB", canvas_size)
    for index, tile in enumerate(images):
        row, col = divmod(index, cols)
        canvas.paste(tile, (col * (cell_width + sep), row * (cell_height + sep)))

    return canvas
240
+
241
+
242
def as_numpy(array: Array) -> np.ndarray:
    """Fetches an array to the host as a NumPy array with a canonical dtype.

    Floating-point arrays become float32, integer arrays become int32 and
    boolean arrays stay boolean; any other dtype is passed through as-is.

    Args:
        array: The array to convert

    Returns:
        The converted NumPy array
    """
    host = jax.device_get(array)
    canonical_dtypes = (
        (jnp.floating, jnp.float32),
        (jnp.integer, jnp.int32),
        (jnp.bool_, jnp.bool_),
    )
    for family, target in canonical_dtypes:
        if jax.dtypes.issubdtype(host.dtype, family):
            host = host.astype(target)
            break
    return np.array(host)
251
+
252
+
253
def get_image(image: np.ndarray | Array | PILImage, target_resolution: tuple[int, int] | None = None) -> PILImage:
    """Converts an array-like image to a PIL image.

    Arrays may have shape HW, HWC or CHW, with one or three channels.
    Floating-point arrays are min-max normalized to the 0-255 range; uint8
    arrays are used as-is.

    Args:
        image: The image to convert; PIL images are passed through
        target_resolution: If set, the image is resized to approximately the
            area of this resolution

    Returns:
        The PIL image

    Raises:
        ValueError: If the image type, dtype or channel layout is unsupported
        RuntimeError: If the array does not have two or three dimensions
    """
    if not isinstance(image, (np.ndarray, Array, PILImage)):
        raise ValueError(f"Unsupported image type: {type(image)}")
    if isinstance(image, Array):
        # ``as_numpy`` widens every integer dtype to int32, which would make a
        # uint8 image fail the dtype check below, so the uint8 dtype is
        # restored after the transfer (the round trip is lossless).
        is_uint8 = image.dtype == jnp.uint8
        image = as_numpy(image)
        if is_uint8:
            image = image.astype(np.uint8)
    if isinstance(image, np.ndarray):
        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)
        if image.ndim != 3:
            raise RuntimeError(f"Expected image to have shape HW, HWC, or CHW, got {image.shape}")

        # Normalizes the image and converts to integer.
        if np.issubdtype(image.dtype, np.floating):
            image = (normalize(image) * 255).round().astype(np.uint8)
        elif image.dtype == np.uint8:
            pass
        else:
            raise ValueError(f"Unsupported image dtype: {image.dtype}")

        # Converts to a PIL image; channels-last layouts are checked first.
        if image.shape[-1] == 1:
            image = Image.fromarray(image[..., 0])
        elif image.shape[-1] == 3:
            image = Image.fromarray(image)
        elif image.shape[0] == 1:
            image = Image.fromarray(image[0])
        elif image.shape[0] == 3:
            image = Image.fromarray(image.transpose(1, 2, 0))
        else:
            raise ValueError(f"Unsupported image shape: {image.shape}")

    if target_resolution is not None:
        image = make_human_viewable_resolution(image, trg_res=target_resolution)
    return image
287
+
288
+
289
@dataclass
class LogImage:
    """Wraps a single image that is ready to be written by a logger backend."""

    image: PILImage
292
+
293
+
294
@dataclass
class LogLine:
    """Holds everything logged for one step, as handed to ``LoggerImpl.write``.

    Each value mapping is keyed first by namespace and then by key.
    """

    state: State
    scalars: dict[str, dict[str, Number]]
    strings: dict[str, dict[str, str]]
    images: dict[str, dict[str, LogImage]]
300
+
301
+
302
@dataclass
class LogErrorSummary:
    """Carries the text of an error summary to the logger backends."""

    message: str
305
+
306
+
307
@dataclass
class LogError:
    """Carries an error message and, optionally, where it was raised."""

    message: str
    location: str | None = None

    @property
    def message_with_location(self) -> str:
        """Returns the message, followed by the location in parentheses if set."""
        message = self.message
        if self.location is not None:
            message += f" ({self.location})"
        return message
318
+
319
+
320
@dataclass
class LogStatus:
    """Carries a status message to the logger backends."""

    message: str
    # Timestamp of the message; ``Logger.write_status`` fills in
    # ``time.time()`` when the caller does not provide one.
    created: float
    filename: str | None = None
    lineno: int | None = None
326
+
327
+
328
@dataclass
class LogPing:
    """Carries a ping message to the logger backends."""

    message: str
    # Timestamp of the message; ``Logger.write_ping`` fills in
    # ``time.time()`` when the caller does not provide one.
    created: float
    filename: str | None = None
    lineno: int | None = None
334
+
335
+
336
class LoggerImpl(ABC):
    """Base class for logger backends driven by ``Logger``."""

    def __init__(self, log_interval_seconds: float = 1.0) -> None:
        """Sets up the shared behavior of every logger backend.

        Subclasses must implement ``write``, which sends a packed log line to
        its destination. All other hooks are optional and do nothing by
        default. The base class keeps one interval ticker per phase, which
        ``should_log`` consults to avoid writing too many log lines.

        Args:
            log_interval_seconds: The interval between successive log lines.
        """
        super().__init__()

        tickers = {}
        for phase in get_args(Phase):
            tickers[phase] = IntervalTicker(log_interval_seconds)
        self.tickers = tickers

    def start(self) -> None:
        """Hook called when the owning ``Logger`` context is entered."""
        return None

    def stop(self) -> None:
        """Hook called when the owning ``Logger`` context is exited."""
        return None

    @abstractmethod
    def write(self, line: LogLine) -> None:
        """Writes the values logged for the current step.

        Args:
            line: The line to write.
        """

    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
        """Writes an error summary; does nothing unless overridden.

        Args:
            error_summary: The error summary to write.
        """

    def write_error(self, error: LogError) -> None:
        """Writes an error line; does nothing unless overridden.

        Args:
            error: The error information to write.
        """

    def write_status(self, status: LogStatus) -> None:
        """Writes a status line; does nothing unless overridden.

        Args:
            status: The status to write.
        """

    def write_ping(self, ping: LogPing) -> None:
        """Writes a ping line; does nothing unless overridden.

        Args:
            ping: The ping to write.
        """

    def log_git_state(self, git_state: str) -> None:
        """Records the Git state of the current run; does nothing unless overridden.

        Args:
            git_state: The Git state, as text blocks.
        """

    def log_training_code(self, training_code: str) -> None:
        """Records the training script code; does nothing unless overridden.

        Args:
            training_code: The training script code.
        """

    def log_config(self, config: DictConfig) -> None:
        """Records the configuration of the current run; does nothing unless overridden.

        Args:
            config: The configuration, as a DictConfig.
        """

    def should_log(self, state: State) -> bool:
        """Decides whether this backend should write the current step.

        The decision is delegated to the ticker of the state's phase, which
        is ticked with the state's elapsed time.

        Args:
            state: The current step's state.

        Returns:
            If the logger should log the current step.
        """
        phase_ticker = self.tickers[state.phase]
        return phase_ticker.tick(state.elapsed_time_s)
425
+
426
+
427
class ToastHandler(logging.Handler):
    """Logging handler that forwards selected log records to a ``Logger``.

    Records are routed by level:

    - ``LOG_ERROR_SUMMARY`` goes to ``Logger.write_error_summary``
    - ``LOG_STATUS`` goes to ``Logger.write_status``
    - ``LOG_PING`` and ``logging.WARNING`` go to ``Logger.write_ping``
    - ``logging.ERROR`` and ``logging.CRITICAL`` go to ``Logger.write_error``

    Records at any other level are ignored.
    """

    def __init__(self, logger: "Logger") -> None:
        super().__init__()

        self.logger = logger

    def emit(self, record: logging.LogRecord) -> None:
        """Forwards a log record to the matching ``Logger`` method.

        Args:
            record: The log record to forward.
        """
        try:
            if record.levelno == LOG_ERROR_SUMMARY:
                self.logger.write_error_summary(record.getMessage())
            elif record.levelno == LOG_STATUS:
                self.logger.write_status(record.getMessage(), record.filename, record.lineno)
            elif record.levelno in (LOG_PING, logging.WARNING):
                self.logger.write_ping(record.getMessage(), record.filename, record.lineno)
            # Warnings are already consumed by the ping branch above, so only
            # the levels that can actually reach this branch are listed.
            elif record.levelno in (logging.ERROR, logging.CRITICAL):
                self.logger.write_error(record.getMessage(), f"{record.filename}:{record.lineno}")
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)

    def add_for_logger(self, logger: logging.Logger) -> None:
        """Installs this handler on a logger, replacing any other ``ToastHandler``.

        Args:
            logger: The standard-library logger to attach to.
        """
        # Removes existing ToastHandler.
        handlers_to_remove = [handler for handler in logger.handlers if isinstance(handler, ToastHandler)]
        for handler in handlers_to_remove:
            logger.removeHandler(handler)

        # Adds the new ToastHandler.
        logger.addHandler(self)
459
+
460
+
461
class Logger:
    """Defines an intermediate container which holds values to log somewhere else.

    Values registered through the ``log_*`` methods are stored as memoized,
    zero-argument callables, so any conversion work only happens when
    ``write`` finds at least one backend that wants to log the step. The
    logger must be entered as a context manager before values can be logged.
    """

    def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
        self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
        self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
        self.images: dict[str, dict[str, Callable[[], PILImage]]] = defaultdict(dict)
        self.default_namespace = default_namespace
        self.loggers: list[LoggerImpl] = []

        # Registers a logging handler to route log messages to the logger.
        root_logger = logging.getLogger()
        ToastHandler(self).add_for_logger(root_logger)

        # Flag when the logger is active.
        self.active = False

    def add_logger(self, *logger: LoggerImpl) -> None:
        """Add the logger, so that it gets called when `write` is called.

        Args:
            logger: The logger to add.
        """
        self.loggers.extend(logger)

    def pack(self, state: State) -> LogLine:
        """Evaluates every pending value and bundles the results.

        Args:
            state: The current step's state.

        Returns:
            The log line holding the evaluated scalars, strings and images.
        """
        return LogLine(
            state=state,
            scalars={ns: {key: fn() for key, fn in values.items()} for ns, values in self.scalars.items()},
            strings={ns: {key: fn() for key, fn in values.items()} for ns, values in self.strings.items()},
            images={ns: {key: LogImage(fn()) for key, fn in values.items()} for ns, values in self.images.items()},
        )

    def clear(self) -> None:
        """Drops every pending scalar, string and image."""
        self.scalars.clear()
        self.strings.clear()
        self.images.clear()

    def write(self, state: State) -> None:
        """Writes the current step's logging information.

        Pending values are cleared whether or not anything was written.

        Args:
            state: The current step's state.
        """
        # Every backend is polled, even once one has already said yes.
        wants_line = [impl.should_log(state) for impl in self.loggers]
        if not any(wants_line):
            self.clear()
            return
        line = self.pack(state)
        self.clear()
        for impl, wanted in zip(self.loggers, wants_line):
            if wanted:
                impl.write(line)

    def write_error_summary(self, error_summary: str) -> None:
        """Sends an error summary to every backend.

        Args:
            error_summary: The error summary text.
        """
        for logger in self.loggers:
            logger.write_error_summary(LogErrorSummary(error_summary))

    def write_error(self, message: str, location: str | None = None) -> None:
        """Sends an error message to every backend.

        Args:
            message: The error message.
            location: An optional description of where the error occurred.
        """
        for logger in self.loggers:
            logger.write_error(LogError(message, location))

    def write_status(
        self,
        message: str,
        filename: str | None = None,
        lineno: int | None = None,
        created: float | None = None,
    ) -> None:
        """Sends a status message to every backend.

        Args:
            message: The status message.
            filename: The file the message originated from, if known.
            lineno: The line the message originated from, if known.
            created: The message timestamp; defaults to the current time.
        """
        status = LogStatus(message, time.time() if created is None else created, filename, lineno)
        for logger in self.loggers:
            logger.write_status(status)

    def write_ping(
        self,
        message: str,
        filename: str | None = None,
        lineno: int | None = None,
        created: float | None = None,
    ) -> None:
        """Sends a ping message to every backend.

        Args:
            message: The ping message.
            filename: The file the message originated from, if known.
            lineno: The line the message originated from, if known.
            created: The message timestamp; defaults to the current time.
        """
        ping = LogPing(message, time.time() if created is None else created, filename, lineno)
        for logger in self.loggers:
            logger.write_ping(ping)

    def resolve_namespace(self, namespace: str | None = None) -> str:
        """Builds the full namespace, including the active namespace stack.

        Args:
            namespace: The base namespace; the default namespace if None.

        Returns:
            The base namespace and the entries of ``NAMESPACE_STACK``, joined
            with underscores.
        """
        return "_".join([self.default_namespace if namespace is None else namespace] + NAMESPACE_STACK)

    def _namespace_if_active(self, namespace: str | None) -> str:
        """Resolves a namespace, refusing to do so while the logger is inactive."""
        if not self.active:
            raise RuntimeError("The logger is not active")
        return self.resolve_namespace(namespace)

    def log_scalar(self, key: str, value: Callable[[], Number] | Number, *, namespace: str | None = None) -> None:
        """Logs a scalar value.

        Args:
            key: The key being logged
            value: The scalar value being logged
            namespace: An optional logging namespace

        Raises:
            RuntimeError: If the logger is not active
        """
        namespace = self._namespace_if_active(namespace)

        @functools.lru_cache(maxsize=None)
        def scalar_future() -> Number:
            return value() if callable(value) else value

        self.scalars[namespace][key] = scalar_future

    def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
        """Logs a string value.

        Args:
            key: The key being logged
            value: The string value being logged
            namespace: An optional logging namespace

        Raises:
            RuntimeError: If the logger is not active
        """
        namespace = self._namespace_if_active(namespace)

        @functools.lru_cache(maxsize=None)
        def value_future() -> str:
            return value() if callable(value) else value

        self.strings[namespace][key] = value_future

    def log_image(
        self,
        key: str,
        value: Callable[[], np.ndarray | Array | PILImage] | np.ndarray | Array | PILImage,
        *,
        namespace: str | None = None,
        target_resolution: tuple[int, int] | None = (512, 512),
    ) -> None:
        """Logs an image.

        Args:
            key: The key being logged
            value: The image being logged
            namespace: An optional logging namespace
            target_resolution: The target resolution for each image; if None,
                don't resample the images

        Raises:
            RuntimeError: If the logger is not active
        """
        namespace = self._namespace_if_active(namespace)

        @functools.lru_cache(maxsize=None)
        def image_future() -> PILImage:
            return get_image(value() if callable(value) else value, target_resolution)

        self.images[namespace][key] = image_future

    def log_labeled_image(
        self,
        key: str,
        value: Callable[[], tuple[np.ndarray | Array | PILImage, str]] | tuple[np.ndarray | Array | PILImage, str],
        *,
        namespace: str | None = None,
        max_line_length: int | None = None,
        max_num_lines: int | None = None,
        target_resolution: tuple[int, int] | None = (512, 512),
        line_spacing: int = 2,
        centered: bool = True,
    ) -> None:
        """Logs an image with a label.

        Args:
            key: The key being logged
            value: The image and label being logged
            namespace: An optional logging namespace
            max_line_length: The maximum line length for the label
            max_num_lines: The number of lines of spacing to add to the bottom
                of the image
            target_resolution: The target resolution for each image; if None,
                don't resample the images
            line_spacing: The spacing between adjacent lines
            centered: If set, center the text labels, otherwise align to the left

        Raises:
            RuntimeError: If the logger is not active
        """
        namespace = self._namespace_if_active(namespace)

        @functools.lru_cache(maxsize=None)
        def image_future() -> PILImage:
            image, label = value() if callable(value) else value
            image = get_image(image, target_resolution)
            return image_with_text(
                image,
                standardize_text(label, max_line_length),
                max_num_lines=max_num_lines,
                line_spacing=line_spacing,
                centered=centered,
            )

        self.images[namespace][key] = image_future

    def log_images(
        self,
        key: str,
        value: (
            Callable[[], Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array]
            | Sequence[np.ndarray | Array | PILImage]
            | np.ndarray
            | Array
        ),
        *,
        namespace: str | None = None,
        max_images: int | None = None,
        target_resolution: tuple[int, int] | None = (256, 256),
        sep: int = 0,
    ) -> None:
        """Logs a set of images.

        The images are tiled to be nearly-square.

        Args:
            key: The key being logged
            value: The images being logged
            namespace: An optional logging namespace
            max_images: The maximum number of images to show; extra images
                are clipped
            target_resolution: The target resolution for each image; if None,
                don't resample the images
            sep: An optional separation amount between adjacent images

        Raises:
            RuntimeError: If the logger is not active
        """
        namespace = self._namespace_if_active(namespace)

        @functools.lru_cache(maxsize=None)
        def images_future() -> PILImage:
            images = value() if callable(value) else value
            if max_images is not None:
                images = images[:max_images]
            if isinstance(images, Array):
                images = as_numpy(images)
            if isinstance(images, Sequence):
                images = list(images)
            images = [get_image(image, target_resolution) for image in images]
            return tile_images(images, sep)

        self.images[namespace][key] = images_future

    def log_labeled_images(
        self,
        key: str,
        value: (
            Callable[[], tuple[Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array, Sequence[str]]]
            | tuple[Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array, Sequence[str]]
        ),
        *,
        namespace: str | None = None,
        max_images: int | None = None,
        max_line_length: int | None = None,
        max_num_lines: int | None = None,
        target_resolution: tuple[int, int] | None = (256, 256),
        line_spacing: int = 2,
        centered: bool = True,
        sep: int = 0,
    ) -> None:
        """Logs a set of images with labels.

        The images are tiled to be nearly-square.

        Args:
            key: The key being logged
            value: The images and labels being logged
            namespace: An optional logging namespace
            max_images: The maximum number of images to show; extra images
                are clipped
            max_line_length: The maximum line length for the label
            max_num_lines: The number of lines of spacing to add to the bottom
                of the image
            target_resolution: The target resolution for each image; if None,
                don't resample the images
            line_spacing: The spacing between adjacent lines
            centered: If set, center the text labels, otherwise align to the left
            sep: An optional separation amount between adjacent images

        Raises:
            RuntimeError: If the logger is not active
        """
        namespace = self._namespace_if_active(namespace)

        @functools.lru_cache(maxsize=None)
        def images_future() -> PILImage:
            images, labels = value() if callable(value) else value
            if max_images is not None:
                images = images[:max_images]
                labels = labels[:max_images]
            images = [get_image(image, target_resolution) for image in images]
            # Images and labels are paired positionally; ``zip`` stops at the
            # shorter of the two.
            images = [
                image_with_text(
                    image,
                    standardize_text(label, max_line_length),
                    max_num_lines=max_num_lines,
                    line_spacing=line_spacing,
                    centered=centered,
                )
                for image, label in zip(images, labels)
            ]
            return tile_images(images, sep)

        self.images[namespace][key] = images_future

    def log_git_state(self, git_state: str) -> None:
        """Forwards the Git state of the current run to every backend.

        Args:
            git_state: The Git state, as text blocks.
        """
        for logger in self.loggers:
            logger.log_git_state(git_state)

    def log_training_code(self, training_code: str) -> None:
        """Forwards the training script code to every backend.

        Args:
            training_code: The training script code.
        """
        for logger in self.loggers:
            logger.log_training_code(training_code)

    def log_config(self, config: DictConfig) -> None:
        """Forwards the run configuration to every backend.

        Args:
            config: The configuration, as a DictConfig.
        """
        for logger in self.loggers:
            logger.log_config(config)

    def __enter__(self) -> Self:
        self.active = True
        for logger in self.loggers:
            logger.start()
        return self

    def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
        try:
            for logger in self.loggers:
                logger.stop()
        finally:
            # The logger must not stay marked as active if a backend fails
            # while stopping.
            self.active = False