xax 0.0.3__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/task/logger.py CHANGED
@@ -11,16 +11,21 @@ captions to images, and so on.
11
11
 
12
12
  import functools
13
13
  import logging
14
+ import math
15
+ import re
14
16
  import time
15
17
  from abc import ABC, abstractmethod
16
18
  from collections import defaultdict
17
19
  from dataclasses import dataclass
18
20
  from types import TracebackType
19
- from typing import Callable, Literal, Self, Sequence, TypeVar, get_args
21
+ from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, get_args
20
22
 
23
+ import jax
24
+ import jax.numpy as jnp
21
25
  import numpy as np
22
26
  from jaxtyping import Array
23
- from omegaconf import DictConfig
27
+ from PIL import Image, ImageDraw, ImageFont
28
+ from PIL.Image import Image as PILImage
24
29
 
25
30
  from xax.core.state import Phase, State
26
31
  from xax.utils.experiments import IntervalTicker
@@ -34,15 +39,58 @@ Number = int | float | Array | np.ndarray
34
39
 
35
40
  ChannelSelectMode = Literal["first", "last", "mean"]
36
41
 
37
- VALID_VIDEO_CHANNEL_COUNTS = {1, 3}
38
- VALID_AUDIO_CHANNEL_COUNTS = {1, 2}
39
- TARGET_FPS = 12
40
42
  DEFAULT_NAMESPACE = "value"
41
43
 
42
-
43
44
  NAMESPACE_STACK: list[str] = []
44
45
 
45
46
 
47
+ def standardize_text(text: str, max_line_length: int | None = None, remove_non_ascii: bool = False) -> list[str]:
48
+ """Standardizes a text string to a list of lines.
49
+
50
+ Args:
51
+ text: The text to standardize
52
+ max_line_length: If set, truncate lines to this length
53
+ remove_non_ascii: Remove non-ASCII characters if present
54
+
55
+ Returns:
56
+ The standardized text lines
57
+ """
58
+
59
+ def _chunk_lines(text: str, max_length: int) -> Iterator[str]:
60
+ for i in range(0, len(text), max_length):
61
+ yield text[i : i + max_length]
62
+
63
+ if remove_non_ascii:
64
+ text = "".join(char for char in text if ord(char) < 128)
65
+ lines = [re.sub(r"\s+", " ", line) for line in re.split(r"[\n\r]+", text.strip())]
66
+ if max_line_length is not None:
67
+ lines = [subline for line in lines for subline in _chunk_lines(line, max_line_length)]
68
+ return lines
69
+
70
+
71
def make_human_viewable_resolution(
    image: PILImage,
    interpolation: Image.Resampling = Image.Resampling.LANCZOS,
    trg_res: tuple[int, int] = (512, 512),
) -> PILImage:
    """Resizes an image so its area roughly matches a target resolution.

    Both sides are scaled by the same factor, so the aspect ratio is kept. The
    factor is chosen so that the output area is approximately
    ``trg_res[0] * trg_res[1]``.

    Args:
        image: The PIL image to resize
        interpolation: Resampling filter to use for image resizing
        trg_res: The target image resolution as (height, width); only its
            area is used

    Returns:
        The resized image, or the input unchanged if it has zero area
    """
    width, height = image.size
    if width == 0 or height == 0:
        # Nothing to scale; also avoids dividing by a zero area below.
        return image
    trg_height, trg_width = trg_res
    factor = math.sqrt((trg_height * trg_width) / (height * width))
    # Clamp to one pixel so extreme aspect ratios never produce a zero-sized side.
    new_height, new_width = max(1, int(height * factor)), max(1, int(width * factor))
    return image.resize((new_width, new_height), interpolation)
92
+
93
+
46
94
  class namespace_context: # noqa: N801
47
95
  def __init__(self, name: str | None) -> None:
48
96
  self._name = name
@@ -62,45 +110,134 @@ class namespace_context: # noqa: N801
62
110
  NAMESPACE_STACK.pop()
63
111
 
64
112
 
65
- @dataclass
66
- class LogImage:
67
- pixels: Array
113
def normalize(x: np.ndarray) -> np.ndarray:
    """Min-max scales an array to the range [0, 1].

    Args:
        x: The array to scale

    Returns:
        An array where the minimum maps to 0 and the maximum maps to 1. A
        constant or empty array maps to all zeros instead of NaN.
    """
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
    if x.size == 0:
        return np.zeros(x.shape, dtype=out_dtype)
    lo, hi = x.min(), x.max()
    if hi == lo:
        # 0 / 0 would give NaN, which turns into garbage after a uint8 cast.
        return np.zeros(x.shape, dtype=out_dtype)
    return (x - lo) / (hi - lo)
115
+
116
+
117
def ternary_search_optimal_side_counts(height: int, width: int, count: int) -> tuple[int, int]:
    """Chooses a (rows, columns) grid for tiling ``count`` equal-sized images.

    Only exact factorizations ``rows * columns == count`` are considered. The
    chosen grid maximizes the shorter side of the tiled canvas (ignoring
    separators).

    Args:
        height: Height of each image, in pixels
        width: Width of each image, in pixels
        count: Number of images to tile; must be positive

    Returns:
        The (rows, columns) pair; ties go to the pair with fewer rows

    Raises:
        ValueError: If ``count`` is not positive
    """
    if count < 1:
        raise ValueError(f"count must be positive, got {count}")

    # Divisors up to floor(sqrt(count)). Using ceil here would admit a divisor
    # above sqrt(count) and produce duplicate, out-of-order factor pairs.
    small = [i for i in range(1, math.isqrt(count) + 1) if count % i == 0]
    large = [i for i in reversed(small) if i * i != count]
    # Sorted by increasing row count.
    factors = [(i, count // i) for i in small] + [(count // i, i) for i in large]

    def penalty(pair: tuple[int, int]) -> int:
        rows, cols = pair
        return -(min(rows * height, cols * width) ** 2)

    # There are only a handful of divisor pairs, so an exhaustive scan is cheap
    # and always finds the optimum; `min` keeps the first (fewest rows) on ties.
    return min(factors, key=penalty)
147
+
148
+
149
def tile_images_different_sizes(images: list[PILImage], sep: int) -> PILImage:
    """Tiles images of varying sizes side by side in a single row.

    Images are top-aligned and separated horizontally by ``sep`` pixels. Area
    not covered by an image keeps the ``Image.new`` default fill.

    Args:
        images: The images to tile.
        sep: The separation between adjacent images.

    Returns:
        The tiled RGB image; an empty (0x0) image if ``images`` is empty.
    """
    if not images:
        # `max` below would raise on an empty sequence.
        return Image.new("RGB", (0, 0))
    total_width = sum(image.width for image in images) + (len(images) - 1) * sep
    max_height = max(image.height for image in images)
    tiled = Image.new("RGB", (total_width, max_height))
    x = 0
    for image in images:
        tiled.paste(image, (x, 0))
        x += image.width + sep
    return tiled
68
166
 
69
167
 
70
- @dataclass
71
- class LogAudio:
72
- frames: Array
73
- sample_rate: int
168
def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
    """Arranges a list of images into a single grid image.

    Equal-sized images are laid out row-major in a grid whose shape comes from
    ``ternary_search_optimal_side_counts``. Mixed sizes are handed off to
    ``tile_images_different_sizes``.

    Args:
        images: The images to tile.
        sep: Pixel gap between adjacent images.

    Returns:
        The tiled RGB image (0x0 when ``images`` is empty).
    """
    if len(images) == 0:
        return Image.new("RGB", (0, 0))

    first = images[0]
    if any(img.size != first.size for img in images):
        return tile_images_different_sizes(images, sep)

    cell_h, cell_w = first.height, first.width
    rows, cols = ternary_search_optimal_side_counts(cell_h, cell_w, len(images))

    canvas_w = cols * cell_w + (cols - 1) * sep
    canvas_h = rows * cell_h + (rows - 1) * sep
    canvas = Image.new("RGB", (canvas_w, canvas_h))
    for idx, img in enumerate(images):
        row, col = divmod(idx, cols)
        canvas.paste(img, (col * (cell_w + sep), row * (cell_h + sep)))
    return canvas
195
+
196
+
197
def as_numpy(array: Array) -> np.ndarray:
    """Copies an array to host memory as a NumPy array.

    The dtype is adjusted as follows:
    - Floating-point arrays of any precision are cast to float32.
    - ``uint8`` arrays keep their dtype, so image and video pixel data stays
      in the format that ``get_image`` / ``get_video`` accept.
    - Other integer arrays are cast to int32.
    - Boolean arrays stay boolean.

    Args:
        array: The array to convert

    Returns:
        The array as a NumPy array
    """
    array = jax.device_get(array)
    if jax.dtypes.issubdtype(array.dtype, jnp.floating):
        array = array.astype(jnp.float32)
    elif jax.dtypes.issubdtype(array.dtype, jnp.integer) and array.dtype != np.uint8:
        array = array.astype(jnp.int32)
    return np.array(array)
206
+
207
+
208
@dataclass(kw_only=True)
class LogImage:
    """A rendered image, ready to be handed to a logger implementation.

    Attributes:
        image: The PIL image to log
    """

    image: PILImage
211
+
212
+
213
+ @dataclass(kw_only=True)
77
214
  class LogVideo:
78
- frames: Array
215
+ """Container for video data and metadata.
79
216
 
217
+ Attributes:
218
+ frames: Video frames as a numpy array of shape (T,H,W,C)
219
+ fps: Frames per second
220
+ """
80
221
 
81
- @dataclass
82
- class LogPointCloud:
83
- xyz: Array
84
- colors: Array | None
222
+ frames: np.ndarray
223
+ fps: int
85
224
 
86
225
 
87
- @dataclass
226
@dataclass(kw_only=True)
class LogLine:
    """All values collected for one logging step.

    Every value mapping is keyed first by namespace, then by the logged key.

    Attributes:
        state: The state the values were logged at
        scalars: Scalar values
        strings: String values
        images: Rendered images
        videos: Rendered videos
    """

    state: State
    scalars: dict[str, dict[str, Number]]
    strings: dict[str, dict[str, str]]
    images: dict[str, dict[str, LogImage]]
    videos: dict[str, dict[str, LogVideo]]
95
- point_cloud: dict[str, dict[str, LogPointCloud]]
96
233
 
97
234
 
98
- @dataclass
235
+ @dataclass(kw_only=True)
99
236
  class LogErrorSummary:
100
237
  message: str
101
238
 
102
239
 
103
- @dataclass
240
+ @dataclass(kw_only=True)
104
241
  class LogError:
105
242
  message: str
106
243
  location: str | None = None
@@ -113,7 +250,7 @@ class LogError:
113
250
  return message
114
251
 
115
252
 
116
- @dataclass
253
+ @dataclass(kw_only=True)
117
254
  class LogStatus:
118
255
  message: str
119
256
  created: float
@@ -121,7 +258,7 @@ class LogStatus:
121
258
  lineno: int | None = None
122
259
 
123
260
 
124
- @dataclass
261
+ @dataclass(kw_only=True)
125
262
  class LogPing:
126
263
  message: str
127
264
  created: float
@@ -129,6 +266,120 @@ class LogPing:
129
266
  lineno: int | None = None
130
267
 
131
268
 
269
def get_image(image: np.ndarray | Array | PILImage, target_resolution: tuple[int, int] | None = None) -> LogImage:
    """Converts an array or PIL image into a ``LogImage``.

    Args:
        image: The image to convert.
            - Arrays may be HW, HWC or CHW with one (grayscale) or three (RGB)
              channels; the last axis is checked first when the layout is
              ambiguous.
            - Floating-point arrays are min-max scaled to [0, 255].
            - Boolean arrays map False/True to 0/255.
            - Integer arrays already within [0, 255] are used as-is; other
              integer arrays are min-max scaled.
        target_resolution: If set, resize the image so its area roughly
            matches this (height, width) resolution, keeping aspect ratio

    Returns:
        The wrapped PIL image

    Raises:
        ValueError: If the image type, dtype or channel layout is unsupported
        RuntimeError: If the array is not 2- or 3-dimensional
    """
    if not isinstance(image, (np.ndarray, Array, PILImage)):
        raise ValueError(f"Unsupported image type: {type(image)}")
    if isinstance(image, Array):
        image = as_numpy(image)
    if isinstance(image, np.ndarray):
        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)
        if image.ndim != 3:
            raise RuntimeError(f"Expected image to have shape HW, HWC, or CHW, got {image.shape}")

        # Converts the pixel values to uint8.
        if np.issubdtype(image.dtype, np.floating):
            image = (normalize(image) * 255).round().astype(np.uint8)
        elif image.dtype == np.bool_:
            image = np.where(image, 255, 0).astype(np.uint8)
        elif np.issubdtype(image.dtype, np.integer) and image.dtype != np.uint8:
            if image.size == 0 or (image.min() >= 0 and image.max() <= 255):
                # Already in pixel range, e.g. uint8 data widened to int32.
                image = image.astype(np.uint8)
            else:
                image = (normalize(image) * 255).round().astype(np.uint8)
        elif image.dtype != np.uint8:
            raise ValueError(f"Unsupported image dtype: {image.dtype}")

        # Converts to a PIL image, preferring channels-last when ambiguous.
        if image.shape[-1] == 1:
            image = Image.fromarray(image[..., 0])
        elif image.shape[-1] == 3:
            image = Image.fromarray(image)
        elif image.shape[0] == 1:
            image = Image.fromarray(image[0])
        elif image.shape[0] == 3:
            image = Image.fromarray(image.transpose(1, 2, 0))
        else:
            raise ValueError(f"Unsupported image shape: {image.shape}")

    if target_resolution is not None:
        image = make_human_viewable_resolution(image, trg_res=target_resolution)
    return LogImage(image=image)
303
+
304
+
305
def image_with_text(
    image: PILImage,
    text: list[str],
    max_num_lines: int | None,
    line_spacing: int,
    centered: bool,
) -> LogImage:
    """Adds a text label beneath an image.

    The image is padded at the bottom with a white strip that has room for
    ``max_num_lines`` lines. The label is then drawn into that strip in black.

    Args:
        image: The PIL image to label
        text: The label, one entry per line
        max_num_lines: The number of lines of space to add below the image;
            extra lines of ``text`` are dropped. If None, use ``len(text)``
        line_spacing: The spacing in pixels between adjacent lines (also used
            as the left margin when not centered)
        centered: If set, center the text labels, otherwise align to the left

    Returns:
        The image with a text label, or the unchanged image if there is no
        text to draw
    """
    if max_num_lines is None:
        max_num_lines = len(text)
    text = text[:max_num_lines]
    if not text:
        # Nothing to draw (empty label, or max_num_lines == 0).
        return LogImage(image=image)

    width, height = image.size
    font = ImageFont.load_default()
    # Size line slots from the lowest glyph bottom across all lines, so later
    # lines with descenders are not clipped.
    line_height = max(font.getbbox(line)[3] for line in text)
    new_height = height + line_spacing + max_num_lines * (line_height + line_spacing)

    # Pillow expects a per-band value for multi-band modes; a named color
    # resolves correctly for every mode, unlike a bare integer.
    padded_image = Image.new(image.mode, (width, new_height), "white")
    padded_image.paste(image, (0, 0))
    drawer = ImageDraw.Draw(padded_image)
    for i, text_line in enumerate(text):
        text_line_top = height + line_spacing + i * (line_height + line_spacing)
        if centered:
            line_width = font.getbbox(text_line)[2]
            # Lines wider than the image start at the left edge instead of
            # being clipped on both sides.
            text_line_left = max(0, (width - line_width) / 2)
        else:
            text_line_left = line_spacing
        drawer.text((text_line_left, text_line_top), text_line, font=font, fill="black")
    return LogImage(image=padded_image)
347
+
348
+
349
def get_video(video: np.ndarray | Array, fps: int = 30) -> LogVideo:
    """Converts video data to the standard (T, H, W, C) uint8 format.

    Args:
        video: The video frames, as a NumPy or JAX array of shape
            (T, H, W, C) or (T, C, H, W).
            - The second axis is treated as channels when it has size 3, or
              size 1 while the last axis is not 1 or 3.
            - Floating-point frames are min-max scaled to [0, 255].
            - Boolean frames map to 0/255.
            - Integer frames within [0, 255] are used as-is; other integer
              frames are min-max scaled.
        fps: Frames per second; must be positive

    Returns:
        LogVideo containing standardized video frames

    Raises:
        ValueError: If the fps, type, rank or dtype is unsupported
    """
    if fps <= 0:
        raise ValueError(f"Expected fps to be positive, got {fps}")

    if isinstance(video, Array):
        video = as_numpy(video)

    if not isinstance(video, np.ndarray):
        raise ValueError(f"Unsupported video type: {type(video)}")

    # Handle different dimension orderings
    if video.ndim != 4:
        raise ValueError(f"Expected video array of shape (T, H, W, C) or (T, C, H, W), got shape {video.shape}")

    if video.shape[1] == 3 or (video.shape[1] == 1 and video.shape[-1] not in (1, 3)):  # (T,C,H,W) format
        video = video.transpose(0, 2, 3, 1)

    # Normalize and convert to uint8 if needed
    if np.issubdtype(video.dtype, np.floating):
        video = (normalize(video) * 255).round().astype(np.uint8)
    elif video.dtype == np.bool_:
        video = np.where(video, 255, 0).astype(np.uint8)
    elif np.issubdtype(video.dtype, np.integer) and video.dtype != np.uint8:
        if video.size == 0 or (video.min() >= 0 and video.max() <= 255):
            # Already in pixel range, e.g. uint8 data widened to int32.
            video = video.astype(np.uint8)
        else:
            video = (normalize(video) * 255).round().astype(np.uint8)
    elif video.dtype != np.uint8:
        raise ValueError(f"Unsupported video dtype: {video.dtype}")

    return LogVideo(frames=video, fps=fps)
381
+
382
+
132
383
  class LoggerImpl(ABC):
133
384
  def __init__(self, log_interval_seconds: float = 1.0) -> None:
134
385
  """Defines some default behavior for loggers.
@@ -187,25 +438,12 @@ class LoggerImpl(ABC):
187
438
  ping: The ping to write.
188
439
  """
189
440
 
190
- def log_git_state(self, git_state: str) -> None:
191
- """Logs Git state for the current run.
192
-
193
- Args:
194
- git_state: The Git state, as text blocks.
195
- """
196
-
197
- def log_training_code(self, training_code: str) -> None:
198
- """Logs the training script code.
441
+ def log_file(self, name: str, contents: str) -> None:
442
+ """Logs a large text file.
199
443
 
200
444
  Args:
201
- training_code: The training script code.
202
- """
203
-
204
- def log_config(self, config: DictConfig) -> None:
205
- """Logs the configuration for the current run.
206
-
207
- Args:
208
- config: The configuration, as a DictConfig.
445
+ name: The name of the file.
446
+ contents: The contents of the file.
209
447
  """
210
448
 
211
449
  def should_log(self, state: State) -> bool:
@@ -260,11 +498,8 @@ class Logger:
260
498
  def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
261
499
  self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
262
500
  self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
263
- self.images: dict[str, dict[str, Callable[[], Array]]] = defaultdict(dict)
264
- self.audio: dict[str, dict[str, Callable[[], tuple[Array, int]]]] = defaultdict(dict)
265
- self.videos: dict[str, dict[str, Callable[[], Array]]] = defaultdict(dict)
266
- self.histograms: dict[str, dict[str, Callable[[], Array]]] = defaultdict(dict)
267
- self.point_clouds: dict[str, dict[str, Callable[[], tuple[Array, Array | None]]]] = defaultdict(dict)
501
+ self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
502
+ self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
268
503
  self.default_namespace = default_namespace
269
504
  self.loggers: list[LoggerImpl] = []
270
505
 
@@ -272,6 +507,9 @@ class Logger:
272
507
  root_logger = logging.getLogger()
273
508
  ToastHandler(self).add_for_logger(root_logger)
274
509
 
510
+ # Flag when the logger is active.
511
+ self.active = False
512
+
275
513
  def add_logger(self, *logger: LoggerImpl) -> None:
276
514
  """Add the logger, so that it gets called when `write` is called.
277
515
 
@@ -285,20 +523,15 @@ class Logger:
285
523
  state=state,
286
524
  scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
287
525
  strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
288
- images={k: {kk: LogImage(v()) for kk, v in v.items()} for k, v in self.images.items()},
289
- audios={k: {kk: LogAudio(*v()) for kk, v in v.items()} for k, v in self.audio.items()},
290
- videos={k: {kk: LogVideo(v()) for kk, v in v.items()} for k, v in self.videos.items()},
291
- point_cloud={k: {kk: LogPointCloud(*v()) for kk, v in v.items()} for k, v in self.point_clouds.items()},
526
+ images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
527
+ videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
292
528
  )
293
529
 
294
530
  def clear(self) -> None:
295
531
  self.scalars.clear()
296
532
  self.strings.clear()
297
533
  self.images.clear()
298
- self.audio.clear()
299
534
  self.videos.clear()
300
- self.histograms.clear()
301
- self.point_clouds.clear()
302
535
 
303
536
  def write(self, state: State) -> None:
304
537
  """Writes the current step's logging information.
@@ -317,11 +550,11 @@ class Logger:
317
550
 
318
551
  def write_error_summary(self, error_summary: str) -> None:
319
552
  for logger in self.loggers:
320
- logger.write_error_summary(LogErrorSummary(error_summary))
553
+ logger.write_error_summary(LogErrorSummary(message=error_summary))
321
554
 
322
555
  def write_error(self, message: str, location: str | None = None) -> None:
323
556
  for logger in self.loggers:
324
- logger.write_error(LogError(message, location))
557
+ logger.write_error(LogError(message=message, location=location))
325
558
 
326
559
  def write_status(
327
560
  self,
@@ -330,7 +563,12 @@ class Logger:
330
563
  lineno: int | None = None,
331
564
  created: float | None = None,
332
565
  ) -> None:
333
- status = LogStatus(message, time.time() if created is None else created, filename, lineno)
566
+ status = LogStatus(
567
+ message=message,
568
+ created=time.time() if created is None else created,
569
+ filename=filename,
570
+ lineno=lineno,
571
+ )
334
572
  for logger in self.loggers:
335
573
  logger.write_status(status)
336
574
 
@@ -341,7 +579,12 @@ class Logger:
341
579
  lineno: int | None = None,
342
580
  created: float | None = None,
343
581
  ) -> None:
344
- ping = LogPing(message, time.time() if created is None else created, filename, lineno)
582
+ ping = LogPing(
583
+ message=message,
584
+ created=time.time() if created is None else created,
585
+ filename=filename,
586
+ lineno=lineno,
587
+ )
345
588
  for logger in self.loggers:
346
589
  logger.write_ping(ping)
347
590
 
@@ -356,8 +599,13 @@ class Logger:
356
599
  value: The scalar value being logged
357
600
  namespace: An optional logging namespace
358
601
  """
602
+ if not self.active:
603
+ raise RuntimeError("The logger is not active")
359
604
  namespace = self.resolve_namespace(namespace)
360
605
 
606
+ if isinstance(value, jnp.ndarray):
607
+ assert value.ndim == 0, f"Scalar must be a 0D array, got shape {value.shape}"
608
+
361
609
  @functools.lru_cache(maxsize=None)
362
610
  def scalar_future() -> Number:
363
611
  return value() if callable(value) else value
@@ -372,6 +620,8 @@ class Logger:
372
620
  value: The string value being logged
373
621
  namespace: An optional logging namespace
374
622
  """
623
+ if not self.active:
624
+ raise RuntimeError("The logger is not active")
375
625
  namespace = self.resolve_namespace(namespace)
376
626
 
377
627
  @functools.lru_cache(maxsize=None)
@@ -383,71 +633,87 @@ class Logger:
383
633
  def log_image(
384
634
  self,
385
635
  key: str,
386
- value: Callable[[], Array] | Array,
636
+ value: Callable[[], np.ndarray | Array | PILImage] | np.ndarray | Array | PILImage,
387
637
  *,
388
638
  namespace: str | None = None,
389
- keep_resolution: bool = False,
639
+ target_resolution: tuple[int, int] | None = (512, 512),
390
640
  ) -> None:
391
641
  """Logs an image.
392
642
 
393
643
  Args:
394
644
  key: The key being logged
395
- value: The image being logged; can be (C, H, W), (H, W, C) or (H, W)
396
- as an RGB (3 channel) or grayscale (1 channel) image
645
+ value: The image being logged
397
646
  namespace: An optional logging namespace
398
- keep_resolution: If set, keep the image resolution the same,
399
- otherwise upscale or downscale the image to a standard
400
- resolution
647
+ target_resolution: The target resolution for each image; if None,
648
+ don't resample the images
401
649
  """
650
+ if not self.active:
651
+ raise RuntimeError("The logger is not active")
402
652
  namespace = self.resolve_namespace(namespace)
403
653
 
404
654
  @functools.lru_cache(maxsize=None)
405
- def image_future() -> Array:
406
- raise NotImplementedError
655
+ def image_future() -> LogImage:
656
+ return get_image(value() if callable(value) else value, target_resolution)
407
657
 
408
658
  self.images[namespace][key] = image_future
409
659
 
410
660
  def log_labeled_image(
411
661
  self,
412
662
  key: str,
413
- value: Callable[[], tuple[Array, str]] | tuple[Array, str],
663
+ value: Callable[[], tuple[np.ndarray | Array | PILImage, str]] | tuple[np.ndarray | Array | PILImage, str],
414
664
  *,
415
665
  namespace: str | None = None,
416
666
  max_line_length: int | None = None,
417
- keep_resolution: bool = False,
667
+ max_num_lines: int | None = None,
668
+ target_resolution: tuple[int, int] | None = (512, 512),
669
+ line_spacing: int = 2,
418
670
  centered: bool = True,
419
671
  ) -> None:
420
672
  """Logs an image with a label.
421
673
 
422
674
  Args:
423
675
  key: The key being logged
424
- value: The image and label being logged; the image can be (C, H, W),
425
- (H, W, C) or (H, W) as an RGB (3 channel) or grayscale
426
- (1 channel) image
676
+ value: The image and label being logged
427
677
  namespace: An optional logging namespace
428
- max_line_length: Labels longer than this length are wrapped around
429
- keep_resolution: If set, keep the image resolution the same,
430
- otherwise upscale or downscale the image to a standard
431
- resolution
432
- centered: If set, center the text labels, otherwise align to the
433
- left
678
+ max_line_length: The maximum line length for the label
679
+ max_num_lines: The number of lines of spacing to add to the bottom
680
+ of the image
681
+ target_resolution: The target resolution for each image; if None,
682
+ don't resample the images
683
+ line_spacing: The spacing between adjacent lines
684
+ centered: If set, center the text labels, otherwise align to the left
434
685
  """
686
+ if not self.active:
687
+ raise RuntimeError("The logger is not active")
435
688
  namespace = self.resolve_namespace(namespace)
436
689
 
437
690
  @functools.lru_cache(maxsize=None)
438
- def labeled_image_future() -> Array:
439
- raise NotImplementedError
691
+ def image_future() -> LogImage:
692
+ image, label = value() if callable(value) else value
693
+ image = get_image(image, target_resolution)
694
+ return image_with_text(
695
+ image.image,
696
+ standardize_text(label, max_line_length),
697
+ max_num_lines=max_num_lines,
698
+ line_spacing=line_spacing,
699
+ centered=centered,
700
+ )
440
701
 
441
- self.images[namespace][key] = labeled_image_future
702
+ self.images[namespace][key] = image_future
442
703
 
443
704
  def log_images(
444
705
  self,
445
706
  key: str,
446
- value: Callable[[], Array] | Array,
707
+ value: (
708
+ Callable[[], Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array]
709
+ | Sequence[np.ndarray | Array | PILImage]
710
+ | np.ndarray
711
+ | Array
712
+ ),
447
713
  *,
448
714
  namespace: str | None = None,
449
- keep_resolution: bool = False,
450
715
  max_images: int | None = None,
716
+ target_resolution: tuple[int, int] | None = (256, 256),
451
717
  sep: int = 0,
452
718
  ) -> None:
453
719
  """Logs a set of images.
@@ -456,35 +722,49 @@ class Logger:
456
722
 
457
723
  Args:
458
724
  key: The key being logged
459
- value: The images being logged; can be (B, C, H, W), (B, H, W, C)
460
- or (B H, W) as an RGB (3 channel) or grayscale (1 channel) image
725
+ value: The images being logged
461
726
  namespace: An optional logging namespace
462
- keep_resolution: If set, keep the image resolution the same,
463
- otherwise upscale or downscale the image to a standard
464
- resolution
465
727
  max_images: The maximum number of images to show; extra images
466
728
  are clipped
729
+ target_resolution: The target resolution for each image; if None,
730
+ don't resample the images
467
731
  sep: An optional separation amount between adjacent images
468
732
  """
733
+ if not self.active:
734
+ raise RuntimeError("The logger is not active")
469
735
  namespace = self.resolve_namespace(namespace)
470
736
 
471
737
  @functools.lru_cache(maxsize=None)
472
- def images_future() -> Array:
473
- raise NotImplementedError
738
+ def images_future() -> LogImage:
739
+ images = value() if callable(value) else value
740
+ if max_images is not None:
741
+ images = images[:max_images]
742
+ if isinstance(images, Array):
743
+ images = as_numpy(images)
744
+ if isinstance(images, Sequence):
745
+ images = list(images)
746
+ images = [get_image(image, target_resolution) for image in images]
747
+ tiled = tile_images([img.image for img in images], sep)
748
+ return LogImage(image=tiled)
474
749
 
475
750
  self.images[namespace][key] = images_future
476
751
 
477
752
  def log_labeled_images(
478
753
  self,
479
754
  key: str,
480
- value: Callable[[], tuple[Array, Sequence[str]]] | tuple[Array, Sequence[str]],
755
+ value: (
756
+ Callable[[], tuple[Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array, Sequence[str]]]
757
+ | tuple[Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array, Sequence[str]]
758
+ ),
481
759
  *,
482
760
  namespace: str | None = None,
483
- max_line_length: int | None = None,
484
- keep_resolution: bool = False,
485
761
  max_images: int | None = None,
486
- sep: int = 0,
762
+ max_line_length: int | None = None,
763
+ max_num_lines: int | None = None,
764
+ target_resolution: tuple[int, int] | None = (256, 256),
765
+ line_spacing: int = 2,
487
766
  centered: bool = True,
767
+ sep: int = 0,
488
768
  ) -> None:
489
769
  """Logs a set of images with labels.
490
770
 
@@ -492,353 +772,79 @@ class Logger:
492
772
 
493
773
  Args:
494
774
  key: The key being logged
495
- value: The images and labels being logged; images can be
496
- (B, C, H, W), (B, H, W, C) or (B, H, W) as an RGB (3 channel)
497
- or grayscale (1 channel) image, with exactly B labels
775
+ value: The images and labels being logged
498
776
  namespace: An optional logging namespace
499
- max_line_length: Labels longer than this length are wrapped around
500
- keep_resolution: If set, keep the image resolution the same,
501
- otherwise upscale or downscale the image to a standard
502
- resolution
503
777
  max_images: The maximum number of images to show; extra images
504
778
  are clipped
779
+ max_line_length: The maximum line length for the label
780
+ max_num_lines: The number of lines of spacing to add to the bottom
781
+ of the image
782
+ target_resolution: The target resolution for each image; if None,
783
+ don't resample the images
784
+ line_spacing: The spacing between adjacent lines
785
+ centered: If set, center the text labels, otherwise align to the left
505
786
  sep: An optional separation amount between adjacent images
506
- centered: If set, center the text labels, otherwise align to the
507
- left
508
- """
509
- namespace = self.resolve_namespace(namespace)
510
-
511
- @functools.lru_cache(maxsize=None)
512
- def labeled_images_future() -> Array:
513
- raise NotImplementedError
514
-
515
- self.images[namespace][key] = labeled_images_future
516
-
517
- def log_audio(
518
- self,
519
- key: str,
520
- value: Callable[[], Array] | Array,
521
- *,
522
- namespace: str | None = None,
523
- sample_rate: int = 44100,
524
- log_spec: bool = False,
525
- n_fft_ms: float = 32.0,
526
- hop_length_ms: float | None = None,
527
- channel_select_mode: ChannelSelectMode = "first",
528
- keep_resolution: bool = False,
529
- ) -> None:
530
- """Logs an audio clip.
531
-
532
- Args:
533
- key: The key being logged
534
- value: The audio clip being logged; can be (C, T) or (T) as
535
- a mono (1 channel) or stereo (2 channel) audio clip
536
- namespace: An optional logging namespace
537
- sample_rate: The sample rate of the audio clip
538
- log_spec: If set, also log the spectrogram
539
- n_fft_ms: FFT size, in milliseconds
540
- hop_length_ms: The FFT hop length, in milliseconds
541
- channel_select_mode: How to select the channel if the audio is
542
- stereo; can be "first", "last", or "mean"; this is only used
543
- for the spectrogram
544
- keep_resolution: If set, keep the resolution of the
545
- spectrogram; otherwise, make human-viewable
546
787
  """
788
+ if not self.active:
789
+ raise RuntimeError("The logger is not active")
547
790
  namespace = self.resolve_namespace(namespace)
548
791
 
549
792
  @functools.lru_cache(maxsize=None)
550
- def raw_audio_future() -> Array:
551
- raise NotImplementedError
793
+ def images_future() -> LogImage:
794
+ images, labels = value() if callable(value) else value
795
+ if max_images is not None:
796
+ images = images[:max_images]
797
+ labels = labels[:max_images]
798
+ images = [get_image(image, target_resolution) for image in images]
799
+ labeled = [
800
+ image_with_text(
801
+ img.image,
802
+ standardize_text(label, max_line_length),
803
+ max_num_lines=max_num_lines,
804
+ line_spacing=line_spacing,
805
+ centered=centered,
806
+ )
807
+ for img, label in zip(images, labels)
808
+ ]
809
+ tiled = tile_images([img.image for img in labeled], sep)
810
+ return LogImage(image=tiled)
552
811
 
553
- @functools.lru_cache(maxsize=None)
554
- def audio_future() -> tuple[Array, int]:
555
- raise NotImplementedError
556
-
557
- self.audio[namespace][key] = audio_future
558
-
559
- if log_spec:
560
- # Using a unique key for the spectrogram is very important because
561
- # otherwise Tensorboard will have some issues.
562
- self.log_spectrogram(
563
- key=f"{key}_spec",
564
- value=raw_audio_future,
565
- namespace=namespace,
566
- sample_rate=sample_rate,
567
- n_fft_ms=n_fft_ms,
568
- hop_length_ms=hop_length_ms,
569
- channel_select_mode=channel_select_mode,
570
- keep_resolution=keep_resolution,
571
- )
572
-
573
- def log_audios(
574
- self,
575
- key: str,
576
- value: Callable[[], Array] | Array,
577
- *,
578
- namespace: str | None = None,
579
- sep_ms: float = 0.0,
580
- max_audios: int | None = None,
581
- sample_rate: int = 44100,
582
- log_spec: bool = False,
583
- n_fft_ms: float = 32.0,
584
- hop_length_ms: float | None = None,
585
- channel_select_mode: ChannelSelectMode = "first",
586
- spec_sep: int = 0,
587
- keep_resolution: bool = False,
588
- ) -> None:
589
- """Logs multiple audio clips.
590
-
591
- Args:
592
- key: The key being logged
593
- value: The audio clip being logged; can be (B, C, T) or (B, T) as
594
- a mono (1 channel) or stereo (2 channel) audio clip, with
595
- exactly B clips
596
- namespace: An optional logging namespace
597
- sep_ms: An optional separation amount between adjacent audio clips
598
- max_audios: An optional maximum number of audio clips to log
599
- sample_rate: The sample rate of the audio clip
600
- log_spec: If set, also log the spectrogram
601
- n_fft_ms: FFT size, in milliseconds
602
- hop_length_ms: The FFT hop length, in milliseconds
603
- channel_select_mode: How to select the channel if the audio is
604
- stereo; can be "first", "last", or "mean"; this is only used
605
- for the spectrogram
606
- spec_sep: An optional separation amount between adjacent
607
- spectrograms
608
- keep_resolution: If set, keep the resolution of the
609
- spectrogram; otherwise, make human-viewable
610
- """
611
- namespace = self.resolve_namespace(namespace)
612
-
613
- @functools.lru_cache(maxsize=None)
614
- def raw_audio_future() -> Array:
615
- raise NotImplementedError
616
-
617
- @functools.lru_cache(maxsize=None)
618
- def audio_future() -> tuple[Array, int]:
619
- raise NotImplementedError
620
-
621
- self.audio[namespace][key] = audio_future
622
-
623
- if log_spec:
624
- # Using a unique key for the spectrogram is very important because
625
- # otherwise Tensorboard will have some issues.
626
- self.log_spectrograms(
627
- key=f"{key}_spec",
628
- value=raw_audio_future,
629
- namespace=namespace,
630
- max_audios=max_audios,
631
- sample_rate=sample_rate,
632
- n_fft_ms=n_fft_ms,
633
- hop_length_ms=hop_length_ms,
634
- channel_select_mode=channel_select_mode,
635
- spec_sep=spec_sep,
636
- keep_resolution=keep_resolution,
637
- )
638
-
639
- def log_spectrogram(
640
- self,
641
- key: str,
642
- value: Callable[[], Array] | Array,
643
- *,
644
- namespace: str | None = None,
645
- sample_rate: int = 44100,
646
- n_fft_ms: float = 32.0,
647
- hop_length_ms: float | None = None,
648
- channel_select_mode: ChannelSelectMode = "first",
649
- keep_resolution: bool = False,
650
- ) -> None:
651
- """Logs spectrograms of an audio clip.
652
-
653
- Args:
654
- key: The key being logged
655
- value: The audio clip being logged; can be (C, T) or (T) as
656
- a mono (1 channel) or stereo (2 channel) audio clip
657
- namespace: An optional logging namespace
658
- sample_rate: The sample rate of the audio clip
659
- n_fft_ms: FFT size, in milliseconds
660
- hop_length_ms: The FFT hop length, in milliseconds
661
- channel_select_mode: How to select the channel if the audio is
662
- stereo; can be "first", "last", or "mean"; this is only used
663
- for the spectrogram
664
- keep_resolution: If set, keep the resolution of the
665
- spectrogram; otherwise, make human-viewable
666
- """
667
- namespace = self.resolve_namespace(namespace)
668
-
669
- @functools.lru_cache(maxsize=None)
670
- def spec_future() -> Array:
671
- raise NotImplementedError
672
-
673
- self.images[namespace][key] = spec_future
674
-
675
- def log_spectrograms(
676
- self,
677
- key: str,
678
- value: Callable[[], Array] | Array,
679
- *,
680
- namespace: str | None = None,
681
- max_audios: int | None = None,
682
- sample_rate: int = 44100,
683
- n_fft_ms: float = 32.0,
684
- hop_length_ms: float | None = None,
685
- channel_select_mode: ChannelSelectMode = "first",
686
- spec_sep: int = 0,
687
- keep_resolution: bool = False,
688
- ) -> None:
689
- """Logs spectrograms of audio clips.
690
-
691
- Args:
692
- key: The key being logged
693
- value: The audio clip being logged; can be (B, C, T) or (B, T) as
694
- a mono (1 channel) or stereo (2 channel) audio clip, with
695
- exactly B clips
696
- namespace: An optional logging namespace
697
- max_audios: An optional maximum number of audio clips to log
698
- sample_rate: The sample rate of the audio clip
699
- n_fft_ms: FFT size, in milliseconds
700
- hop_length_ms: The FFT hop length, in milliseconds
701
- channel_select_mode: How to select the channel if the audio is
702
- stereo; can be "first", "last", or "mean"; this is only used
703
- for the spectrogram
704
- spec_sep: An optional separation amount between adjacent
705
- spectrograms
706
- keep_resolution: If set, keep the resolution of the
707
- spectrogram; otherwise, make human-viewable
708
- """
709
- namespace = self.resolve_namespace(namespace)
710
-
711
- @functools.lru_cache(maxsize=None)
712
- def spec_future() -> Array:
713
- raise NotImplementedError
812
+ self.images[namespace][key] = images_future
714
813
 
715
- self.images[namespace][key] = spec_future
814
+ def log_file(self, name: str, contents: str) -> None:
815
+ for logger in self.loggers:
816
+ logger.log_file(name, contents)
716
817
 
717
818
  def log_video(
718
819
  self,
719
820
  key: str,
720
- value: Callable[[], Array] | Array,
821
+ value: Callable[[], np.ndarray | Array] | np.ndarray | Array,
721
822
  *,
823
+ fps: int = 30,
722
824
  namespace: str | None = None,
723
- fps: int | None = None,
724
- length: float | None = None,
725
825
  ) -> None:
726
826
  """Logs a video.
727
827
 
728
828
  Args:
729
829
  key: The key being logged
730
- value: The video being logged; the video can be (T, C, H, W),
731
- (T, H, W, C) or (T, H, W) as an RGB (3 channel) or grayscale
732
- (1 channel) video
830
+ value: The video frames. Can be:
831
+ - A numpy array of shape (T,H,W,C) or (T,C,H,W)
832
+ - A JAX array of shape (T,H,W,C) or (T,C,H,W)
833
+ fps: Frames per second
733
834
  namespace: An optional logging namespace
734
- fps: The video frames per second
735
- length: The desired video length, in seconds, at the target FPS
736
835
  """
836
+ if not self.active:
837
+ raise RuntimeError("The logger is not active")
737
838
  namespace = self.resolve_namespace(namespace)
738
839
 
739
840
  @functools.lru_cache(maxsize=None)
740
- def video_future() -> Array:
741
- raise NotImplementedError
841
+ def video_future() -> LogVideo:
842
+ return get_video(value() if callable(value) else value, fps=fps)
742
843
 
743
844
  self.videos[namespace][key] = video_future
744
845
 
745
- def log_videos(
746
- self,
747
- key: str,
748
- value: Callable[[], Array | list[Array]] | Array | list[Array],
749
- *,
750
- namespace: str | None = None,
751
- max_videos: int | None = None,
752
- sep: int = 0,
753
- fps: int | None = None,
754
- length: int | None = None,
755
- ) -> None:
756
- """Logs a set of video.
757
-
758
- Args:
759
- key: The key being logged
760
- value: The videos being logged; the video can be (B, T, C, H, W),
761
- (B, T, H, W, C) or (B T, H, W) as an RGB (3 channel) or
762
- grayscale (1 channel) video
763
- namespace: An optional logging namespace
764
- max_videos: The maximum number of videos to show; extra images
765
- are clipped
766
- sep: An optional separation amount between adjacent videos
767
- fps: The video frames per second
768
- length: The desired video length, in seconds, at the target FPS
769
- """
770
- namespace = self.resolve_namespace(namespace)
771
-
772
- @functools.lru_cache(maxsize=None)
773
- def videos_future() -> Array:
774
- raise NotImplementedError
775
-
776
- self.videos[namespace][key] = videos_future
777
-
778
- def log_histogram(
779
- self,
780
- key: str,
781
- value: Callable[[], Array] | Array,
782
- *,
783
- namespace: str | None = None,
784
- ) -> None:
785
- """Logs a histogram.
786
-
787
- Args:
788
- key: The key being logged
789
- value: The values to create a histogram from, with arbitrary shape
790
- namespace: An optional logging namespace
791
- """
792
- namespace = self.resolve_namespace(namespace)
793
-
794
- @functools.lru_cache(maxsize=None)
795
- def histogram_future() -> Array:
796
- raise NotImplementedError
797
-
798
- self.histograms[namespace][key] = histogram_future
799
-
800
- def log_point_cloud(
801
- self,
802
- key: str,
803
- value: Callable[[], Array] | Array,
804
- *,
805
- namespace: str | None = None,
806
- max_points: int = 1000,
807
- colors: Callable[[], Array] | Array | None = None,
808
- ) -> None:
809
- """Logs a point cloud.
810
-
811
- Args:
812
- key: The key being logged
813
- value: The point cloud values, with shape (N, 3) or (B, ..., 3);
814
- can pass multiple batches in order to show multiple point
815
- clouds
816
- namespace: An optional logging namespace
817
- max_points: An optional maximum number of points in the point cloud
818
- colors: An optional color for each point, with the same shape as
819
- the point cloud
820
- """
821
- namespace = self.resolve_namespace(namespace)
822
-
823
- @functools.lru_cache(maxsize=None)
824
- def point_cloud_future() -> tuple[Array, Array | None]:
825
- raise NotImplementedError
826
-
827
- self.point_clouds[namespace][key] = point_cloud_future
828
-
829
- def log_git_state(self, git_state: str) -> None:
830
- for logger in self.loggers:
831
- logger.log_git_state(git_state)
832
-
833
- def log_training_code(self, training_code: str) -> None:
834
- for logger in self.loggers:
835
- logger.log_training_code(training_code)
836
-
837
- def log_config(self, config: DictConfig) -> None:
838
- for logger in self.loggers:
839
- logger.log_config(config)
840
-
841
846
  def __enter__(self) -> Self:
847
+ self.active = True
842
848
  for logger in self.loggers:
843
849
  logger.start()
844
850
  return self
@@ -846,3 +852,4 @@ class Logger:
846
852
  def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
847
853
  for logger in self.loggers:
848
854
  logger.stop()
855
+ self.active = False