xax 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/task/logger.py CHANGED
@@ -11,16 +11,22 @@ captions to images, and so on.
11
11
 
12
12
  import functools
13
13
  import logging
14
+ import math
15
+ import re
14
16
  import time
15
17
  from abc import ABC, abstractmethod
16
18
  from collections import defaultdict
17
19
  from dataclasses import dataclass
18
20
  from types import TracebackType
19
- from typing import Callable, Literal, Self, Sequence, TypeVar, get_args
21
+ from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, get_args
20
22
 
23
+ import jax
24
+ import jax.numpy as jnp
21
25
  import numpy as np
22
26
  from jaxtyping import Array
23
27
  from omegaconf import DictConfig
28
+ from PIL import Image, ImageDraw, ImageFont
29
+ from PIL.Image import Image as PILImage
24
30
 
25
31
  from xax.core.state import Phase, State
26
32
  from xax.utils.experiments import IntervalTicker
@@ -34,15 +40,102 @@ Number = int | float | Array | np.ndarray
34
40
 
# Ways of reducing multi-channel data to a single channel: take the first
# channel, take the last channel, or average over all channels.
ChannelSelectMode = Literal["first", "last", "mean"]

# Fallback namespace; used as the default for `Logger(default_namespace=...)`.
DEFAULT_NAMESPACE = "value"

# Stack of active namespace names, maintained by `namespace_context` (which
# pops its entry from here on exit).
NAMESPACE_STACK: list[str] = []
44
46
 
45
47
 
48
+ def standardize_text(text: str, max_line_length: int | None = None, remove_non_ascii: bool = False) -> list[str]:
49
+ """Standardizes a text string to a list of lines.
50
+
51
+ Args:
52
+ text: The text to standardize
53
+ max_line_length: If set, truncate lines to this length
54
+ remove_non_ascii: Remove non-ASCII characters if present
55
+
56
+ Returns:
57
+ The standardized text lines
58
+ """
59
+
60
+ def _chunk_lines(text: str, max_length: int) -> Iterator[str]:
61
+ for i in range(0, len(text), max_length):
62
+ yield text[i : i + max_length]
63
+
64
+ if remove_non_ascii:
65
+ text = "".join(char for char in text if ord(char) < 128)
66
+ lines = [re.sub(r"\s+", " ", line) for line in re.split(r"[\n\r]+", text.strip())]
67
+ if max_line_length is not None:
68
+ lines = [subline for line in lines for subline in _chunk_lines(line, max_line_length)]
69
+ return lines
70
+
71
+
72
def make_human_viewable_resolution(
    image: PILImage,
    interpolation: Image.Resampling = Image.Resampling.LANCZOS,
    trg_res: tuple[int, int] = (512, 512),
) -> PILImage:
    """Resizes an image to a human-viewable resolution.

    The image is scaled uniformly, preserving its aspect ratio, so that its
    area approximately matches the area of `trg_res`.

    Args:
        image: The PIL image to resize
        interpolation: Interpolation mode to use for image resizing
        trg_res: The target image resolution as (height, width); only its area
            matters, since the input's aspect ratio is kept

    Returns:
        The resized image, or the input unchanged if it has zero area
    """
    width, height = image.size
    if width <= 0 or height <= 0:
        # There is nothing to scale, and computing the factor would divide by zero.
        return image
    trg_height, trg_width = trg_res
    factor = math.sqrt((trg_height * trg_width) / (height * width))
    # Clamp to at least one pixel so very elongated images never collapse to a
    # zero-sized axis, which `Image.resize` rejects.
    new_height = max(1, int(height * factor))
    new_width = max(1, int(width * factor))
    return image.resize((new_width, new_height), interpolation)
93
+
94
+
95
def image_with_text(
    image: PILImage,
    text: list[str],
    max_num_lines: int | None,
    line_spacing: int,
    centered: bool,
) -> PILImage:
    """Adds a text label below an image.

    The image is padded at the bottom with a white strip, and the label lines
    are drawn onto that strip in black.

    Args:
        image: The PIL image to label
        text: The text label for the image, one entry per line
        max_num_lines: If set, at most this many lines are drawn and the strip
            is sized for exactly this many lines, so labels of different
            lengths produce equally tall images; otherwise the strip fits
            every line of `text`
        line_spacing: The spacing between adjacent lines, in pixels; also the
            left margin when the text is not centered
        centered: If set, center the text labels, otherwise align to the left

    Returns:
        The image with a text label, or the original image if there is no
        text to draw
    """
    if max_num_lines is None:
        max_num_lines = len(text)
    else:
        text = text[:max_num_lines]
    # Check emptiness *after* truncation; `max_num_lines=0` would otherwise
    # leave an empty list and fail when measuring the first line.
    if not text:
        return image

    width, height = image.size
    font = ImageFont.load_default()
    # Size lines by the tallest one (plus a reference string with ascenders
    # and descenders), so lines never overlap whatever the first line contains.
    line_height = max(font.getbbox(line)[3] for line in (*text, "Ag"))
    new_height = height + line_spacing + max_num_lines * (line_height + line_spacing)

    # Use color names rather than bare ints: for multi-band modes an int is a
    # packed pixel value, so `255` would paint an RGB background red.
    padded_image = Image.new(image.mode, (width, new_height), "white")
    padded_image.paste(image, (0, 0))
    drawer = ImageDraw.Draw(padded_image)
    for i, text_line in enumerate(text):
        text_line_top = height + line_spacing + i * (line_height + line_spacing)
        if centered:
            left, _, right, _ = font.getbbox(text_line)
            # Offset by the glyph bbox so the drawn ink, not the origin, is centered.
            text_line_left = (width - left - right) / 2
        else:
            text_line_left = line_spacing
        drawer.text((text_line_left, text_line_top), text_line, font=font, fill="black")
    return padded_image
137
+
138
+
46
139
  class namespace_context: # noqa: N801
47
140
  def __init__(self, name: str | None) -> None:
48
141
  self._name = name
@@ -62,26 +155,140 @@ class namespace_context: # noqa: N801
62
155
  NAMESPACE_STACK.pop()
63
156
 
64
157
 
65
- @dataclass
66
- class LogImage:
67
- pixels: Array
158
def normalize(x: np.ndarray) -> np.ndarray:
    """Linearly rescales an array so its values span [0, 1].

    Args:
        x: The array to rescale

    Returns:
        `(x - min) / (max - min)`. A constant array (zero range) maps to all
        zeros instead of the NaNs a plain division would produce.
    """
    lo, hi = x.min(), x.max()
    if hi == lo:
        return np.zeros_like(x, dtype=np.result_type(x.dtype, np.float32))
    return (x - lo) / (hi - lo)
68
160
 
69
161
 
70
- @dataclass
71
- class LogAudio:
72
- frames: Array
73
- sample_rate: int
162
def ternary_search_optimal_side_counts(height: int, width: int, count: int) -> tuple[int, int]:
    """Chooses a grid layout for tiling `count` equally sized images.

    Every factorization `rows * cols == count` is considered. The function
    picks the one that maximizes the shorter side of the tiled image
    (`rows * height` vs. `cols * width`), i.e. the most square-like mosaic.
    Ties are broken in favor of fewer rows.

    The name is kept for backward compatibility. The divisor list is tiny, so
    a linear scan is used; it is exact and does not depend on the candidate
    ordering the way a ternary search does.

    Args:
        height: Height of each image, in pixels
        width: Width of each image, in pixels
        count: Number of images to tile; must be positive

    Returns:
        The `(rows, cols)` grid shape

    Raises:
        ValueError: If `count` is not positive.
    """
    if count < 1:
        raise ValueError(f"`count` must be positive, got {count}")

    # Floor square root keeps every "small" divisor <= sqrt(count). Using a
    # ceiling here duplicated factor pairs whenever ceil(sqrt(count)) divides
    # count, which produced a non-monotonic candidate list.
    small = [i for i in range(1, math.isqrt(count) + 1) if count % i == 0]
    large = [count // i for i in reversed(small) if i * i != count]
    row_counts = small + large  # Strictly increasing candidate row counts.

    def shorter_side(rows: int) -> int:
        return min(rows * height, (count // rows) * width)

    # `max` returns the first maximal element, i.e. the fewest rows on ties.
    best_rows = max(row_counts, key=shorter_side)
    return best_rows, count // best_rows
192
+
193
+
194
def tile_images_different_sizes(images: list[PILImage], sep: int) -> PILImage:
    """Concatenates images left-to-right into a single RGB image.

    Unlike the grid used for equally sized images, every image goes on one
    row, top-aligned, with `sep` pixels between neighbours. Areas not covered
    by an image (gaps, and the space below shorter images) keep the new
    canvas's default black fill.

    Args:
        images: The images to tile; must be non-empty.
        sep: The separation between adjacent images, in pixels.

    Returns:
        The tiled image.
    """
    widths = [image.width for image in images]
    canvas_width = sum(widths) + sep * (len(images) - 1)
    canvas_height = max(image.height for image in images)
    tiled = Image.new("RGB", (canvas_width, canvas_height))

    offset = 0
    for image, image_width in zip(images, widths):
        tiled.paste(image, (offset, 0))
        offset += image_width + sep
    return tiled
211
+
212
+
213
def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
    """Arranges images into a near-square grid on one RGB canvas.

    Equally sized images are placed row-major into the grid chosen by
    `ternary_search_optimal_side_counts`. If the sizes differ, the images are
    laid out in a single row instead.

    Args:
        images: The images to tile.
        sep: The separation between adjacent images, in pixels.

    Returns:
        The tiled image; an empty 0x0 RGB image when `images` is empty.
    """
    if not images:
        return Image.new("RGB", (0, 0))

    first = images[0]
    if any(image.size != first.size for image in images):
        # Grid placement assumes every cell has the same size.
        return tile_images_different_sizes(images, sep)

    cell_h, cell_w = first.height, first.width
    rows, cols = ternary_search_optimal_side_counts(cell_h, cell_w, len(images))

    canvas = Image.new("RGB", (cols * cell_w + (cols - 1) * sep, rows * cell_h + (rows - 1) * sep))
    for index, image in enumerate(images):
        row, col = divmod(index, cols)
        canvas.paste(image, (col * (cell_w + sep), row * (cell_h + sep)))
    return canvas
240
+
241
+
242
def as_numpy(array: Array) -> np.ndarray:
    """Copies an array to host memory as a NumPy array with a canonical dtype.

    Floating-point inputs (e.g. bfloat16) are cast to float32. Integer inputs
    of any width or signedness are cast to int32. Booleans stay boolean, and
    any other dtype is left as-is.
    """
    host = jax.device_get(array)
    casts = (
        (jnp.floating, jnp.float32),
        (jnp.integer, jnp.int32),
        (jnp.bool_, jnp.bool_),
    )
    for category, target in casts:
        if jax.dtypes.issubdtype(host.dtype, category):
            host = host.astype(target)
            break
    return np.array(host)
251
+
252
+
253
def get_image(image: np.ndarray | Array | PILImage, target_resolution: tuple[int, int] | None = None) -> PILImage:
    """Converts an array or PIL image into a PIL image for logging.

    Arrays may be laid out as (H, W), (H, W, C) or (C, H, W) with one
    (grayscale) or three (RGB) channels. Channels-last is checked first, so
    ambiguous shapes are read as channels-last.

    Pixel values are converted to uint8 as follows:
    - floating point: min-max normalized to the full [0, 255] range;
    - uint8: used as-is;
    - other integers: cast directly if every value already lies in [0, 255]
      (this covers uint8 JAX images, which `as_numpy` widens to int32),
      otherwise min-max normalized;
    - boolean: False -> 0, True -> 255.

    Args:
        image: The image to convert
        target_resolution: If set, resample the image to roughly this many
            pixels via `make_human_viewable_resolution`

    Returns:
        The PIL image

    Raises:
        ValueError: If the input type, dtype or channel layout is unsupported.
        RuntimeError: If an array does not have 2 or 3 dimensions.
    """
    if not isinstance(image, (np.ndarray, Array, PILImage)):
        raise ValueError(f"Unsupported image type: {type(image)}")
    if isinstance(image, Array):
        image = as_numpy(image)
    if isinstance(image, np.ndarray):
        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)
        if image.ndim != 3:
            raise RuntimeError(f"Expected image to have shape HW, HWC, or CHW, got {image.shape}")

        # Converts the pixel values to uint8.
        if np.issubdtype(image.dtype, np.floating):
            image = (normalize(image) * 255).round().astype(np.uint8)
        elif image.dtype == np.uint8:
            pass
        elif np.issubdtype(image.dtype, np.bool_):
            image = np.where(image, np.uint8(255), np.uint8(0))
        elif np.issubdtype(image.dtype, np.integer):
            # Integer data already in the 8-bit range is cast back directly;
            # anything else is rescaled the same way as floating-point data.
            if image.min() >= 0 and image.max() <= 255:
                image = image.astype(np.uint8)
            else:
                image = (normalize(image) * 255).round().astype(np.uint8)
        else:
            raise ValueError(f"Unsupported image dtype: {image.dtype}")

        # Converts to a PIL image.
        if image.shape[-1] == 1:
            image = Image.fromarray(image[..., 0])
        elif image.shape[-1] == 3:
            image = Image.fromarray(image)
        elif image.shape[0] == 1:
            image = Image.fromarray(image[0])
        elif image.shape[0] == 3:
            image = Image.fromarray(image.transpose(1, 2, 0))
        else:
            raise ValueError(f"Unsupported image shape: {image.shape}")

    if target_resolution is not None:
        image = make_human_viewable_resolution(image, trg_res=target_resolution)
    return image
79
287
 
80
288
 
@dataclass
class LogImage:
    """A rendered image attached to a log line.

    When a log line is assembled, each image future registered on the logger
    is called and its result is wrapped in one of these.
    """

    # The fully rendered PIL image (already resized, labeled and/or tiled).
    image: PILImage
85
292
 
86
293
 
87
294
  @dataclass
@@ -90,9 +297,6 @@ class LogLine:
90
297
  scalars: dict[str, dict[str, Number]]
91
298
  strings: dict[str, dict[str, str]]
92
299
  images: dict[str, dict[str, LogImage]]
93
- audios: dict[str, dict[str, LogAudio]]
94
- videos: dict[str, dict[str, LogVideo]]
95
- point_cloud: dict[str, dict[str, LogPointCloud]]
96
300
 
97
301
 
98
302
  @dataclass
@@ -260,11 +464,7 @@ class Logger:
260
464
  def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
261
465
  self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
262
466
  self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
263
- self.images: dict[str, dict[str, Callable[[], Array]]] = defaultdict(dict)
264
- self.audio: dict[str, dict[str, Callable[[], tuple[Array, int]]]] = defaultdict(dict)
265
- self.videos: dict[str, dict[str, Callable[[], Array]]] = defaultdict(dict)
266
- self.histograms: dict[str, dict[str, Callable[[], Array]]] = defaultdict(dict)
267
- self.point_clouds: dict[str, dict[str, Callable[[], tuple[Array, Array | None]]]] = defaultdict(dict)
467
+ self.images: dict[str, dict[str, Callable[[], PILImage]]] = defaultdict(dict)
268
468
  self.default_namespace = default_namespace
269
469
  self.loggers: list[LoggerImpl] = []
270
470
 
@@ -272,6 +472,9 @@ class Logger:
272
472
  root_logger = logging.getLogger()
273
473
  ToastHandler(self).add_for_logger(root_logger)
274
474
 
475
+ # Flag when the logger is active.
476
+ self.active = False
477
+
275
478
  def add_logger(self, *logger: LoggerImpl) -> None:
276
479
  """Add the logger, so that it gets called when `write` is called.
277
480
 
@@ -286,19 +489,12 @@ class Logger:
286
489
  scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
287
490
  strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
288
491
  images={k: {kk: LogImage(v()) for kk, v in v.items()} for k, v in self.images.items()},
289
- audios={k: {kk: LogAudio(*v()) for kk, v in v.items()} for k, v in self.audio.items()},
290
- videos={k: {kk: LogVideo(v()) for kk, v in v.items()} for k, v in self.videos.items()},
291
- point_cloud={k: {kk: LogPointCloud(*v()) for kk, v in v.items()} for k, v in self.point_clouds.items()},
292
492
  )
293
493
 
294
494
  def clear(self) -> None:
295
495
  self.scalars.clear()
296
496
  self.strings.clear()
297
497
  self.images.clear()
298
- self.audio.clear()
299
- self.videos.clear()
300
- self.histograms.clear()
301
- self.point_clouds.clear()
302
498
 
303
499
  def write(self, state: State) -> None:
304
500
  """Writes the current step's logging information.
@@ -356,6 +552,8 @@ class Logger:
356
552
  value: The scalar value being logged
357
553
  namespace: An optional logging namespace
358
554
  """
555
+ if not self.active:
556
+ raise RuntimeError("The logger is not active")
359
557
  namespace = self.resolve_namespace(namespace)
360
558
 
361
559
  @functools.lru_cache(maxsize=None)
@@ -372,6 +570,8 @@ class Logger:
372
570
  value: The string value being logged
373
571
  namespace: An optional logging namespace
374
572
  """
573
+ if not self.active:
574
+ raise RuntimeError("The logger is not active")
375
575
  namespace = self.resolve_namespace(namespace)
376
576
 
377
577
  @functools.lru_cache(maxsize=None)
@@ -383,71 +583,87 @@ class Logger:
383
583
  def log_image(
384
584
  self,
385
585
  key: str,
386
- value: Callable[[], Array] | Array,
586
+ value: Callable[[], np.ndarray | Array | PILImage] | np.ndarray | Array | PILImage,
387
587
  *,
388
588
  namespace: str | None = None,
389
- keep_resolution: bool = False,
589
+ target_resolution: tuple[int, int] | None = (512, 512),
390
590
  ) -> None:
391
591
  """Logs an image.
392
592
 
393
593
  Args:
394
594
  key: The key being logged
395
- value: The image being logged; can be (C, H, W), (H, W, C) or (H, W)
396
- as an RGB (3 channel) or grayscale (1 channel) image
595
+ value: The image being logged
397
596
  namespace: An optional logging namespace
398
- keep_resolution: If set, keep the image resolution the same,
399
- otherwise upscale or downscale the image to a standard
400
- resolution
597
+ target_resolution: The target resolution for each image; if None,
598
+ don't resample the images
401
599
  """
600
+ if not self.active:
601
+ raise RuntimeError("The logger is not active")
402
602
  namespace = self.resolve_namespace(namespace)
403
603
 
404
604
  @functools.lru_cache(maxsize=None)
405
- def image_future() -> Array:
406
- raise NotImplementedError
605
+ def image_future() -> PILImage:
606
+ return get_image(value() if callable(value) else value, target_resolution)
407
607
 
408
608
  self.images[namespace][key] = image_future
409
609
 
410
610
  def log_labeled_image(
411
611
  self,
412
612
  key: str,
413
- value: Callable[[], tuple[Array, str]] | tuple[Array, str],
613
+ value: Callable[[], tuple[np.ndarray | Array | PILImage, str]] | tuple[np.ndarray | Array | PILImage, str],
414
614
  *,
415
615
  namespace: str | None = None,
416
616
  max_line_length: int | None = None,
417
- keep_resolution: bool = False,
617
+ max_num_lines: int | None = None,
618
+ target_resolution: tuple[int, int] | None = (512, 512),
619
+ line_spacing: int = 2,
418
620
  centered: bool = True,
419
621
  ) -> None:
420
622
  """Logs an image with a label.
421
623
 
422
624
  Args:
423
625
  key: The key being logged
424
- value: The image and label being logged; the image can be (C, H, W),
425
- (H, W, C) or (H, W) as an RGB (3 channel) or grayscale
426
- (1 channel) image
626
+ value: The image and label being logged
427
627
  namespace: An optional logging namespace
428
- max_line_length: Labels longer than this length are wrapped around
429
- keep_resolution: If set, keep the image resolution the same,
430
- otherwise upscale or downscale the image to a standard
431
- resolution
432
- centered: If set, center the text labels, otherwise align to the
433
- left
628
+ max_line_length: The maximum line length for the label
629
+ max_num_lines: The number of lines of spacing to add to the bottom
630
+ of the image
631
+ target_resolution: The target resolution for each image; if None,
632
+ don't resample the images
633
+ line_spacing: The spacing between adjacent lines
634
+ centered: If set, center the text labels, otherwise align to the left
434
635
  """
636
+ if not self.active:
637
+ raise RuntimeError("The logger is not active")
435
638
  namespace = self.resolve_namespace(namespace)
436
639
 
437
640
  @functools.lru_cache(maxsize=None)
438
- def labeled_image_future() -> Array:
439
- raise NotImplementedError
641
+ def image_future() -> PILImage:
642
+ image, label = value() if callable(value) else value
643
+ image = get_image(image, target_resolution)
644
+ return image_with_text(
645
+ image,
646
+ standardize_text(label, max_line_length),
647
+ max_num_lines=max_num_lines,
648
+ line_spacing=line_spacing,
649
+ centered=centered,
650
+ )
440
651
 
441
- self.images[namespace][key] = labeled_image_future
652
+ self.images[namespace][key] = image_future
442
653
 
443
654
  def log_images(
444
655
  self,
445
656
  key: str,
446
- value: Callable[[], Array] | Array,
657
+ value: (
658
+ Callable[[], Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array]
659
+ | Sequence[np.ndarray | Array | PILImage]
660
+ | np.ndarray
661
+ | Array
662
+ ),
447
663
  *,
448
664
  namespace: str | None = None,
449
- keep_resolution: bool = False,
450
665
  max_images: int | None = None,
666
+ target_resolution: tuple[int, int] | None = (256, 256),
451
667
  sep: int = 0,
452
668
  ) -> None:
453
669
  """Logs a set of images.
@@ -456,35 +672,48 @@ class Logger:
456
672
 
457
673
  Args:
458
674
  key: The key being logged
459
- value: The images being logged; can be (B, C, H, W), (B, H, W, C)
460
- or (B H, W) as an RGB (3 channel) or grayscale (1 channel) image
675
+ value: The images being logged
461
676
  namespace: An optional logging namespace
462
- keep_resolution: If set, keep the image resolution the same,
463
- otherwise upscale or downscale the image to a standard
464
- resolution
465
677
  max_images: The maximum number of images to show; extra images
466
678
  are clipped
679
+ target_resolution: The target resolution for each image; if None,
680
+ don't resample the images
467
681
  sep: An optional separation amount between adjacent images
468
682
  """
683
+ if not self.active:
684
+ raise RuntimeError("The logger is not active")
469
685
  namespace = self.resolve_namespace(namespace)
470
686
 
471
687
  @functools.lru_cache(maxsize=None)
472
- def images_future() -> Array:
473
- raise NotImplementedError
688
+ def images_future() -> PILImage:
689
+ images = value() if callable(value) else value
690
+ if max_images is not None:
691
+ images = images[:max_images]
692
+ if isinstance(images, Array):
693
+ images = as_numpy(images)
694
+ if isinstance(images, Sequence):
695
+ images = list(images)
696
+ images = [get_image(image, target_resolution) for image in images]
697
+ return tile_images(images, sep)
474
698
 
475
699
  self.images[namespace][key] = images_future
476
700
 
477
701
  def log_labeled_images(
478
702
  self,
479
703
  key: str,
480
- value: Callable[[], tuple[Array, Sequence[str]]] | tuple[Array, Sequence[str]],
704
+ value: (
705
+ Callable[[], tuple[Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array, Sequence[str]]]
706
+ | tuple[Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array, Sequence[str]]
707
+ ),
481
708
  *,
482
709
  namespace: str | None = None,
483
- max_line_length: int | None = None,
484
- keep_resolution: bool = False,
485
710
  max_images: int | None = None,
486
- sep: int = 0,
711
+ max_line_length: int | None = None,
712
+ max_num_lines: int | None = None,
713
+ target_resolution: tuple[int, int] | None = (256, 256),
714
+ line_spacing: int = 2,
487
715
  centered: bool = True,
716
+ sep: int = 0,
488
717
  ) -> None:
489
718
  """Logs a set of images with labels.
490
719
 
@@ -492,339 +721,43 @@ class Logger:
492
721
 
493
722
  Args:
494
723
  key: The key being logged
495
- value: The images and labels being logged; images can be
496
- (B, C, H, W), (B, H, W, C) or (B, H, W) as an RGB (3 channel)
497
- or grayscale (1 channel) image, with exactly B labels
724
+ value: The images and labels being logged
498
725
  namespace: An optional logging namespace
499
- max_line_length: Labels longer than this length are wrapped around
500
- keep_resolution: If set, keep the image resolution the same,
501
- otherwise upscale or downscale the image to a standard
502
- resolution
503
726
  max_images: The maximum number of images to show; extra images
504
727
  are clipped
728
+ max_line_length: The maximum line length for the label
729
+ max_num_lines: The number of lines of spacing to add to the bottom
730
+ of the image
731
+ target_resolution: The target resolution for each image; if None,
732
+ don't resample the images
733
+ line_spacing: The spacing between adjacent lines
734
+ centered: If set, center the text labels, otherwise align to the left
505
735
  sep: An optional separation amount between adjacent images
506
- centered: If set, center the text labels, otherwise align to the
507
- left
508
- """
509
- namespace = self.resolve_namespace(namespace)
510
-
511
- @functools.lru_cache(maxsize=None)
512
- def labeled_images_future() -> Array:
513
- raise NotImplementedError
514
-
515
- self.images[namespace][key] = labeled_images_future
516
-
517
- def log_audio(
518
- self,
519
- key: str,
520
- value: Callable[[], Array] | Array,
521
- *,
522
- namespace: str | None = None,
523
- sample_rate: int = 44100,
524
- log_spec: bool = False,
525
- n_fft_ms: float = 32.0,
526
- hop_length_ms: float | None = None,
527
- channel_select_mode: ChannelSelectMode = "first",
528
- keep_resolution: bool = False,
529
- ) -> None:
530
- """Logs an audio clip.
531
-
532
- Args:
533
- key: The key being logged
534
- value: The audio clip being logged; can be (C, T) or (T) as
535
- a mono (1 channel) or stereo (2 channel) audio clip
536
- namespace: An optional logging namespace
537
- sample_rate: The sample rate of the audio clip
538
- log_spec: If set, also log the spectrogram
539
- n_fft_ms: FFT size, in milliseconds
540
- hop_length_ms: The FFT hop length, in milliseconds
541
- channel_select_mode: How to select the channel if the audio is
542
- stereo; can be "first", "last", or "mean"; this is only used
543
- for the spectrogram
544
- keep_resolution: If set, keep the resolution of the
545
- spectrogram; otherwise, make human-viewable
546
- """
547
- namespace = self.resolve_namespace(namespace)
548
-
549
- @functools.lru_cache(maxsize=None)
550
- def raw_audio_future() -> Array:
551
- raise NotImplementedError
552
-
553
- @functools.lru_cache(maxsize=None)
554
- def audio_future() -> tuple[Array, int]:
555
- raise NotImplementedError
556
-
557
- self.audio[namespace][key] = audio_future
558
-
559
- if log_spec:
560
- # Using a unique key for the spectrogram is very important because
561
- # otherwise Tensorboard will have some issues.
562
- self.log_spectrogram(
563
- key=f"{key}_spec",
564
- value=raw_audio_future,
565
- namespace=namespace,
566
- sample_rate=sample_rate,
567
- n_fft_ms=n_fft_ms,
568
- hop_length_ms=hop_length_ms,
569
- channel_select_mode=channel_select_mode,
570
- keep_resolution=keep_resolution,
571
- )
572
-
573
- def log_audios(
574
- self,
575
- key: str,
576
- value: Callable[[], Array] | Array,
577
- *,
578
- namespace: str | None = None,
579
- sep_ms: float = 0.0,
580
- max_audios: int | None = None,
581
- sample_rate: int = 44100,
582
- log_spec: bool = False,
583
- n_fft_ms: float = 32.0,
584
- hop_length_ms: float | None = None,
585
- channel_select_mode: ChannelSelectMode = "first",
586
- spec_sep: int = 0,
587
- keep_resolution: bool = False,
588
- ) -> None:
589
- """Logs multiple audio clips.
590
-
591
- Args:
592
- key: The key being logged
593
- value: The audio clip being logged; can be (B, C, T) or (B, T) as
594
- a mono (1 channel) or stereo (2 channel) audio clip, with
595
- exactly B clips
596
- namespace: An optional logging namespace
597
- sep_ms: An optional separation amount between adjacent audio clips
598
- max_audios: An optional maximum number of audio clips to log
599
- sample_rate: The sample rate of the audio clip
600
- log_spec: If set, also log the spectrogram
601
- n_fft_ms: FFT size, in milliseconds
602
- hop_length_ms: The FFT hop length, in milliseconds
603
- channel_select_mode: How to select the channel if the audio is
604
- stereo; can be "first", "last", or "mean"; this is only used
605
- for the spectrogram
606
- spec_sep: An optional separation amount between adjacent
607
- spectrograms
608
- keep_resolution: If set, keep the resolution of the
609
- spectrogram; otherwise, make human-viewable
610
- """
611
- namespace = self.resolve_namespace(namespace)
612
-
613
- @functools.lru_cache(maxsize=None)
614
- def raw_audio_future() -> Array:
615
- raise NotImplementedError
616
-
617
- @functools.lru_cache(maxsize=None)
618
- def audio_future() -> tuple[Array, int]:
619
- raise NotImplementedError
620
-
621
- self.audio[namespace][key] = audio_future
622
-
623
- if log_spec:
624
- # Using a unique key for the spectrogram is very important because
625
- # otherwise Tensorboard will have some issues.
626
- self.log_spectrograms(
627
- key=f"{key}_spec",
628
- value=raw_audio_future,
629
- namespace=namespace,
630
- max_audios=max_audios,
631
- sample_rate=sample_rate,
632
- n_fft_ms=n_fft_ms,
633
- hop_length_ms=hop_length_ms,
634
- channel_select_mode=channel_select_mode,
635
- spec_sep=spec_sep,
636
- keep_resolution=keep_resolution,
637
- )
638
-
639
- def log_spectrogram(
640
- self,
641
- key: str,
642
- value: Callable[[], Array] | Array,
643
- *,
644
- namespace: str | None = None,
645
- sample_rate: int = 44100,
646
- n_fft_ms: float = 32.0,
647
- hop_length_ms: float | None = None,
648
- channel_select_mode: ChannelSelectMode = "first",
649
- keep_resolution: bool = False,
650
- ) -> None:
651
- """Logs spectrograms of an audio clip.
652
-
653
- Args:
654
- key: The key being logged
655
- value: The audio clip being logged; can be (C, T) or (T) as
656
- a mono (1 channel) or stereo (2 channel) audio clip
657
- namespace: An optional logging namespace
658
- sample_rate: The sample rate of the audio clip
659
- n_fft_ms: FFT size, in milliseconds
660
- hop_length_ms: The FFT hop length, in milliseconds
661
- channel_select_mode: How to select the channel if the audio is
662
- stereo; can be "first", "last", or "mean"; this is only used
663
- for the spectrogram
664
- keep_resolution: If set, keep the resolution of the
665
- spectrogram; otherwise, make human-viewable
666
- """
667
- namespace = self.resolve_namespace(namespace)
668
-
669
- @functools.lru_cache(maxsize=None)
670
- def spec_future() -> Array:
671
- raise NotImplementedError
672
-
673
- self.images[namespace][key] = spec_future
674
-
675
- def log_spectrograms(
676
- self,
677
- key: str,
678
- value: Callable[[], Array] | Array,
679
- *,
680
- namespace: str | None = None,
681
- max_audios: int | None = None,
682
- sample_rate: int = 44100,
683
- n_fft_ms: float = 32.0,
684
- hop_length_ms: float | None = None,
685
- channel_select_mode: ChannelSelectMode = "first",
686
- spec_sep: int = 0,
687
- keep_resolution: bool = False,
688
- ) -> None:
689
- """Logs spectrograms of audio clips.
690
-
691
- Args:
692
- key: The key being logged
693
- value: The audio clip being logged; can be (B, C, T) or (B, T) as
694
- a mono (1 channel) or stereo (2 channel) audio clip, with
695
- exactly B clips
696
- namespace: An optional logging namespace
697
- max_audios: An optional maximum number of audio clips to log
698
- sample_rate: The sample rate of the audio clip
699
- n_fft_ms: FFT size, in milliseconds
700
- hop_length_ms: The FFT hop length, in milliseconds
701
- channel_select_mode: How to select the channel if the audio is
702
- stereo; can be "first", "last", or "mean"; this is only used
703
- for the spectrogram
704
- spec_sep: An optional separation amount between adjacent
705
- spectrograms
706
- keep_resolution: If set, keep the resolution of the
707
- spectrogram; otherwise, make human-viewable
708
- """
709
- namespace = self.resolve_namespace(namespace)
710
-
711
- @functools.lru_cache(maxsize=None)
712
- def spec_future() -> Array:
713
- raise NotImplementedError
714
-
715
- self.images[namespace][key] = spec_future
716
-
717
- def log_video(
718
- self,
719
- key: str,
720
- value: Callable[[], Array] | Array,
721
- *,
722
- namespace: str | None = None,
723
- fps: int | None = None,
724
- length: float | None = None,
725
- ) -> None:
726
- """Logs a video.
727
-
728
- Args:
729
- key: The key being logged
730
- value: The video being logged; the video can be (T, C, H, W),
731
- (T, H, W, C) or (T, H, W) as an RGB (3 channel) or grayscale
732
- (1 channel) video
733
- namespace: An optional logging namespace
734
- fps: The video frames per second
735
- length: The desired video length, in seconds, at the target FPS
736
- """
737
- namespace = self.resolve_namespace(namespace)
738
-
739
- @functools.lru_cache(maxsize=None)
740
- def video_future() -> Array:
741
- raise NotImplementedError
742
-
743
- self.videos[namespace][key] = video_future
744
-
745
- def log_videos(
746
- self,
747
- key: str,
748
- value: Callable[[], Array | list[Array]] | Array | list[Array],
749
- *,
750
- namespace: str | None = None,
751
- max_videos: int | None = None,
752
- sep: int = 0,
753
- fps: int | None = None,
754
- length: int | None = None,
755
- ) -> None:
756
- """Logs a set of video.
757
-
758
- Args:
759
- key: The key being logged
760
- value: The videos being logged; the video can be (B, T, C, H, W),
761
- (B, T, H, W, C) or (B T, H, W) as an RGB (3 channel) or
762
- grayscale (1 channel) video
763
- namespace: An optional logging namespace
764
- max_videos: The maximum number of videos to show; extra images
765
- are clipped
766
- sep: An optional separation amount between adjacent videos
767
- fps: The video frames per second
768
- length: The desired video length, in seconds, at the target FPS
769
736
  """
737
+ if not self.active:
738
+ raise RuntimeError("The logger is not active")
770
739
  namespace = self.resolve_namespace(namespace)
771
740
 
772
741
  @functools.lru_cache(maxsize=None)
773
- def videos_future() -> Array:
774
- raise NotImplementedError
775
-
776
- self.videos[namespace][key] = videos_future
777
-
778
- def log_histogram(
779
- self,
780
- key: str,
781
- value: Callable[[], Array] | Array,
782
- *,
783
- namespace: str | None = None,
784
- ) -> None:
785
- """Logs a histogram.
786
-
787
- Args:
788
- key: The key being logged
789
- value: The values to create a histogram from, with arbitrary shape
790
- namespace: An optional logging namespace
791
- """
792
- namespace = self.resolve_namespace(namespace)
742
+ def images_future() -> PILImage:
743
+ images, labels = value() if callable(value) else value
744
+ if max_images is not None:
745
+ images = images[:max_images]
746
+ labels = labels[:max_images]
747
+ images = [get_image(image, target_resolution) for image in images]
748
+ images = [
749
+ image_with_text(
750
+ image,
751
+ standardize_text(label, max_line_length),
752
+ max_num_lines=max_num_lines,
753
+ line_spacing=line_spacing,
754
+ centered=centered,
755
+ )
756
+ for image, label in zip(images, labels)
757
+ ]
758
+ return tile_images(images, sep)
793
759
 
794
- @functools.lru_cache(maxsize=None)
795
- def histogram_future() -> Array:
796
- raise NotImplementedError
797
-
798
- self.histograms[namespace][key] = histogram_future
799
-
800
- def log_point_cloud(
801
- self,
802
- key: str,
803
- value: Callable[[], Array] | Array,
804
- *,
805
- namespace: str | None = None,
806
- max_points: int = 1000,
807
- colors: Callable[[], Array] | Array | None = None,
808
- ) -> None:
809
- """Logs a point cloud.
810
-
811
- Args:
812
- key: The key being logged
813
- value: The point cloud values, with shape (N, 3) or (B, ..., 3);
814
- can pass multiple batches in order to show multiple point
815
- clouds
816
- namespace: An optional logging namespace
817
- max_points: An optional maximum number of points in the point cloud
818
- colors: An optional color for each point, with the same shape as
819
- the point cloud
820
- """
821
- namespace = self.resolve_namespace(namespace)
822
-
823
- @functools.lru_cache(maxsize=None)
824
- def point_cloud_future() -> tuple[Array, Array | None]:
825
- raise NotImplementedError
826
-
827
- self.point_clouds[namespace][key] = point_cloud_future
760
+ self.images[namespace][key] = images_future
828
761
 
829
762
  def log_git_state(self, git_state: str) -> None:
830
763
  for logger in self.loggers:
@@ -839,6 +772,7 @@ class Logger:
839
772
  logger.log_config(config)
840
773
 
841
774
  def __enter__(self) -> Self:
775
+ self.active = True
842
776
  for logger in self.loggers:
843
777
  logger.start()
844
778
  return self
@@ -846,3 +780,4 @@ class Logger:
846
780
  def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
847
781
  for logger in self.loggers:
848
782
  logger.stop()
783
+ self.active = False