xax 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +49 -7
- xax/core/conf.py +1 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +8 -4
- xax/requirements-dev.txt +9 -1
- xax/requirements.txt +15 -10
- xax/task/base.py +0 -6
- xax/task/logger.py +328 -393
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/tensorboard.py +2 -5
- xax/task/mixins/__init__.py +2 -1
- xax/task/mixins/artifacts.py +14 -7
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +10 -10
- xax/task/mixins/data_loader.py +6 -9
- xax/task/mixins/gpu_stats.py +3 -3
- xax/task/mixins/logger.py +2 -250
- xax/task/mixins/process.py +4 -0
- xax/task/mixins/train.py +71 -40
- xax/task/task.py +6 -5
- xax/utils/data/collate.py +6 -6
- xax/utils/experiments.py +45 -1
- xax/utils/logging.py +29 -0
- xax/utils/tensorboard.py +49 -29
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/METADATA +15 -14
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.3.dist-info/RECORD +0 -49
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/top_level.txt +0 -0
xax/task/logger.py
CHANGED
@@ -11,16 +11,22 @@ captions to images, and so on.
|
|
11
11
|
|
12
12
|
import functools
|
13
13
|
import logging
|
14
|
+
import math
|
15
|
+
import re
|
14
16
|
import time
|
15
17
|
from abc import ABC, abstractmethod
|
16
18
|
from collections import defaultdict
|
17
19
|
from dataclasses import dataclass
|
18
20
|
from types import TracebackType
|
19
|
-
from typing import Callable, Literal, Self, Sequence, TypeVar, get_args
|
21
|
+
from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, get_args
|
20
22
|
|
23
|
+
import jax
|
24
|
+
import jax.numpy as jnp
|
21
25
|
import numpy as np
|
22
26
|
from jaxtyping import Array
|
23
27
|
from omegaconf import DictConfig
|
28
|
+
from PIL import Image, ImageDraw, ImageFont
|
29
|
+
from PIL.Image import Image as PILImage
|
24
30
|
|
25
31
|
from xax.core.state import Phase, State
|
26
32
|
from xax.utils.experiments import IntervalTicker
|
@@ -34,15 +40,102 @@ Number = int | float | Array | np.ndarray
|
|
34
40
|
|
35
41
|
# Strategy for reducing a multi-channel signal to a single channel.
ChannelSelectMode = Literal["first", "last", "mean"]

# Default value for `Logger(default_namespace=...)`.
DEFAULT_NAMESPACE = "value"

# Stack of currently active namespace names (managed by `namespace_context`).
NAMESPACE_STACK: list[str] = []
|
44
46
|
|
45
47
|
|
48
|
+
def standardize_text(text: str, max_line_length: int | None = None, remove_non_ascii: bool = False) -> list[str]:
|
49
|
+
"""Standardizes a text string to a list of lines.
|
50
|
+
|
51
|
+
Args:
|
52
|
+
text: The text to standardize
|
53
|
+
max_line_length: If set, truncate lines to this length
|
54
|
+
remove_non_ascii: Remove non-ASCII characters if present
|
55
|
+
|
56
|
+
Returns:
|
57
|
+
The standardized text lines
|
58
|
+
"""
|
59
|
+
|
60
|
+
def _chunk_lines(text: str, max_length: int) -> Iterator[str]:
|
61
|
+
for i in range(0, len(text), max_length):
|
62
|
+
yield text[i : i + max_length]
|
63
|
+
|
64
|
+
if remove_non_ascii:
|
65
|
+
text = "".join(char for char in text if ord(char) < 128)
|
66
|
+
lines = [re.sub(r"\s+", " ", line) for line in re.split(r"[\n\r]+", text.strip())]
|
67
|
+
if max_line_length is not None:
|
68
|
+
lines = [subline for line in lines for subline in _chunk_lines(line, max_line_length)]
|
69
|
+
return lines
|
70
|
+
|
71
|
+
|
72
|
+
def make_human_viewable_resolution(
    image: PILImage,
    interpolation: Image.Resampling = Image.Resampling.LANCZOS,
    trg_res: tuple[int, int] = (512, 512),
) -> PILImage:
    """Resizes an image to a human-viewable resolution.

    The aspect ratio is preserved. Only the total pixel area is matched to the
    area of `trg_res`.

    Args:
        image: The PIL image to resize
        interpolation: Interpolation mode to use for image resizing
        trg_res: The target image resolution as (height, width); the image
            will be reshaped to have approximately the same area as an image
            with this resolution

    Returns:
        The resized image; zero-area images are returned unchanged
    """
    width, height = image.size
    if width <= 0 or height <= 0:
        # Nothing to scale, and the area ratio below would divide by zero.
        return image
    trg_height, trg_width = trg_res
    factor = math.sqrt((trg_height * trg_width) / (height * width))
    # Clamp each side to at least one pixel so very elongated images (or a
    # zero target) never request a zero-sized dimension from PIL.
    new_height = max(1, int(height * factor))
    new_width = max(1, int(width * factor))
    return image.resize((new_width, new_height), interpolation)
|
93
|
+
|
94
|
+
|
95
|
+
def image_with_text(
    image: PILImage,
    text: list[str],
    max_num_lines: int | None,
    line_spacing: int,
    centered: bool,
) -> PILImage:
    """Adds a text label below an image.

    The image is pasted at the top of a taller canvas with a white background.
    Each label line is drawn in black underneath it.

    Args:
        image: The PIL image to label
        text: The lines of the text label
        max_num_lines: The number of lines of spacing to add to the bottom
            of the image; extra label lines are dropped. If None, one line of
            spacing is added per label line.
        line_spacing: The spacing between adjacent lines
        centered: If set, center the text labels, otherwise align to the left

    Returns:
        The image with a text label, or the original image if there is no
        text to draw
    """
    if max_num_lines is None:
        max_num_lines = len(text)
    else:
        text = text[:max_num_lines]
    # Checked after truncation: `max_num_lines=0` would otherwise leave an
    # empty list and `text[0]` below would raise IndexError.
    if not text:
        return image

    width, height = image.size
    font = ImageFont.load_default()
    _, _, _, line_height = font.getbbox(text[0])
    new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)

    # Use named colours: a bare int such as 255 is only white for single-band
    # modes; for multi-band modes (e.g. RGB) PIL reads it as a packed value,
    # so 255 would paint the padding pure red.
    padded_image = Image.new(image.mode, (new_width, new_height), "white")
    padded_image.paste(image, (0, 0))
    drawer = ImageDraw.Draw(padded_image)
    for i, text_line in enumerate(text):
        text_line_top = height + line_spacing + i * (line_height + line_spacing)
        if centered:
            _, _, line_width, _ = font.getbbox(text_line)
            # Lines wider than the image start at the left edge rather than
            # at a negative offset.
            text_line_left = max(0, (width - line_width) // 2)
        else:
            text_line_left = line_spacing
        drawer.text((text_line_left, text_line_top), text_line, font=font, fill="black")
    return padded_image
|
137
|
+
|
138
|
+
|
46
139
|
class namespace_context: # noqa: N801
|
47
140
|
def __init__(self, name: str | None) -> None:
|
48
141
|
self._name = name
|
@@ -62,26 +155,140 @@ class namespace_context: # noqa: N801
|
|
62
155
|
NAMESPACE_STACK.pop()
|
63
156
|
|
64
157
|
|
65
|
-
|
66
|
-
|
67
|
-
pixels: Array
|
158
|
+
def normalize(x: np.ndarray) -> np.ndarray:
    """Min-max normalizes an array to the [0, 1] range.

    Args:
        x: The array to normalize (must be non-empty)

    Returns:
        `(x - min) / (max - min)`; an all-zero floating array when every
        element is equal (instead of dividing by zero and producing NaNs)
    """
    lo, hi = x.min(), x.max()
    span = hi - lo
    if span == 0:
        return np.zeros_like(x, dtype=np.result_type(x.dtype, np.float32))
    return (x - lo) / span
|
68
160
|
|
69
161
|
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
162
|
+
def ternary_search_optimal_side_counts(height: int, width: int, count: int) -> tuple[int, int]:
    """Chooses a (rows, cols) grid for tiling `count` equally sized images.

    Every factorization `rows * cols == count` is considered. The one whose
    tiled canvas has the largest shorter side (the most square-looking
    canvas) wins. The name is historical: the number of divisors is tiny, so an
    exhaustive scan is simpler than ternary search and does not rely on the
    candidate list being sorted and unimodal.

    Args:
        height: The height of a single image
        width: The width of a single image
        count: The number of images to tile

    Returns:
        The (rows, cols) grid shape; ties go to the layout with fewer rows

    Raises:
        ValueError: If `count` is not positive
    """
    if count <= 0:
        raise ValueError(f"`count` must be positive, got {count}")

    # Divisors up to floor(sqrt(count)), followed by their cofactors. Using
    # `math.isqrt` (not `ceil(sqrt(...))`) keeps the candidates free of
    # duplicates and in ascending order of rows.
    small = [i for i in range(1, math.isqrt(count) + 1) if count % i == 0]
    large = [count // i for i in reversed(small) if i * i != count]
    row_counts = small + large

    def shorter_side(rows: int) -> int:
        cols = count // rows
        return min(rows * height, cols * width)

    # `max` keeps the first maximum, i.e. the fewest rows among ties.
    best_rows = max(row_counts, key=shorter_side)
    return best_rows, count // best_rows
|
192
|
+
|
193
|
+
|
194
|
+
def tile_images_different_sizes(images: list[PILImage], sep: int) -> PILImage:
    """Tiles images of possibly different sizes into one horizontal strip.

    Images are placed left to right and top-aligned on a black RGB canvas as
    tall as the tallest image.

    Args:
        images: The images to tile.
        sep: The separation, in pixels, between adjacent images.

    Returns:
        The tiled image, or an empty RGB image if `images` is empty.
    """
    if not images:
        # `max` over an empty sequence would raise below.
        return Image.new("RGB", (0, 0))

    total_width = sum(image.width for image in images) + (len(images) - 1) * sep
    max_height = max(image.height for image in images)
    tiled = Image.new("RGB", (total_width, max_height))
    x = 0
    for image in images:
        tiled.paste(image, (x, 0))
        x += image.width + sep
    return tiled
|
211
|
+
|
212
|
+
|
213
|
+
def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
    """Arranges a list of images into a single grid image.

    Same-sized images are laid out row-major in a grid chosen by
    `ternary_search_optimal_side_counts`. Mixed sizes are delegated to
    `tile_images_different_sizes`.

    Args:
        images: The images to tile.
        sep: The gap, in pixels, between neighbouring images.

    Returns:
        The tiled image (an empty RGB image when `images` is empty).
    """
    if not images:
        return Image.new("RGB", (0, 0))

    first_size = images[0].size
    if any(img.size != first_size for img in images):
        return tile_images_different_sizes(images, sep)

    cell_w, cell_h = first_size
    rows, cols = ternary_search_optimal_side_counts(cell_h, cell_w, len(images))

    canvas = Image.new("RGB", (cols * cell_w + (cols - 1) * sep, rows * cell_h + (rows - 1) * sep))
    for idx, img in enumerate(images):
        row, col = divmod(idx, cols)
        canvas.paste(img, (col * (cell_w + sep), row * (cell_h + sep)))
    return canvas
|
240
|
+
|
241
|
+
|
242
|
+
def as_numpy(array: Array) -> np.ndarray:
    """Copies a JAX array to the host as a NumPy array.

    Floating-point inputs are cast to float32 and integer inputs to int32.
    Boolean inputs stay boolean. Any other dtype is passed through unchanged.
    """
    host = jax.device_get(array)
    casts = (
        (jnp.floating, jnp.float32),
        (jnp.integer, jnp.int32),
        (jnp.bool_, jnp.bool_),
    )
    for kind, target in casts:
        if jax.dtypes.issubdtype(host.dtype, kind):
            host = host.astype(target)
            break
    return np.array(host)
|
251
|
+
|
252
|
+
|
253
|
+
def get_image(image: np.ndarray | Array | PILImage, target_resolution: tuple[int, int] | None = None) -> PILImage:
    """Converts an array or PIL image into a PIL image suitable for logging.

    Array inputs may be HW, HWC or CHW with 1 or 3 channels. Floating-point
    arrays are min-max normalized to 0-255. Boolean arrays map False/True to
    0/255. Integer arrays whose values already lie in 0-255 are used as pixel
    values directly; other integer arrays are min-max normalized.

    Args:
        image: The image to convert
        target_resolution: If set, resize the image so its area roughly
            matches this (height, width) resolution

    Returns:
        The converted PIL image

    Raises:
        ValueError: If the input type, dtype or channel layout is unsupported
        RuntimeError: If an array input is not 2- or 3-dimensional
    """
    if not isinstance(image, (np.ndarray, Array, PILImage)):
        raise ValueError(f"Unsupported image type: {type(image)}")
    if isinstance(image, Array):
        image = as_numpy(image)
    if isinstance(image, np.ndarray):
        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)
        if image.ndim != 3:
            raise RuntimeError(f"Expected image to have shape HW, HWC, or CHW, got {image.shape}")

        # Normalizes the image and converts to 8-bit integers.
        if np.issubdtype(image.dtype, np.floating):
            image = (normalize(image) * 255).round().astype(np.uint8)
        elif image.dtype == np.bool_:
            image = image.astype(np.uint8) * 255
        elif np.issubdtype(image.dtype, np.integer):
            # `as_numpy` widens every integer dtype (uint8 included) to int32,
            # so JAX uint8 images arrive here as int32 and must be accepted.
            if image.size == 0 or (image.min() >= 0 and image.max() <= 255):
                image = image.astype(np.uint8)
            else:
                image = (normalize(image) * 255).round().astype(np.uint8)
        else:
            raise ValueError(f"Unsupported image dtype: {image.dtype}")

        # Converts to a PIL image; channels-last layouts are checked first.
        if image.shape[-1] == 1:
            image = Image.fromarray(image[..., 0])
        elif image.shape[-1] == 3:
            image = Image.fromarray(image)
        elif image.shape[0] == 1:
            image = Image.fromarray(image[0])
        elif image.shape[0] == 3:
            image = Image.fromarray(image.transpose(1, 2, 0))
        else:
            raise ValueError(f"Unsupported image shape: {image.shape}")

    if target_resolution is not None:
        image = make_human_viewable_resolution(image, trg_res=target_resolution)
    return image
|
79
287
|
|
80
288
|
|
81
289
|
@dataclass
class LogImage:
    """A single rendered image attached to a log line."""

    # The already-converted PIL image to be written by the logger backends.
    image: PILImage
|
85
292
|
|
86
293
|
|
87
294
|
@dataclass
|
@@ -90,9 +297,6 @@ class LogLine:
|
|
90
297
|
scalars: dict[str, dict[str, Number]]
|
91
298
|
strings: dict[str, dict[str, str]]
|
92
299
|
images: dict[str, dict[str, LogImage]]
|
93
|
-
audios: dict[str, dict[str, LogAudio]]
|
94
|
-
videos: dict[str, dict[str, LogVideo]]
|
95
|
-
point_cloud: dict[str, dict[str, LogPointCloud]]
|
96
300
|
|
97
301
|
|
98
302
|
@dataclass
|
@@ -260,11 +464,7 @@ class Logger:
|
|
260
464
|
def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
|
261
465
|
self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
|
262
466
|
self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
|
263
|
-
self.images: dict[str, dict[str, Callable[[],
|
264
|
-
self.audio: dict[str, dict[str, Callable[[], tuple[Array, int]]]] = defaultdict(dict)
|
265
|
-
self.videos: dict[str, dict[str, Callable[[], Array]]] = defaultdict(dict)
|
266
|
-
self.histograms: dict[str, dict[str, Callable[[], Array]]] = defaultdict(dict)
|
267
|
-
self.point_clouds: dict[str, dict[str, Callable[[], tuple[Array, Array | None]]]] = defaultdict(dict)
|
467
|
+
self.images: dict[str, dict[str, Callable[[], PILImage]]] = defaultdict(dict)
|
268
468
|
self.default_namespace = default_namespace
|
269
469
|
self.loggers: list[LoggerImpl] = []
|
270
470
|
|
@@ -272,6 +472,9 @@ class Logger:
|
|
272
472
|
root_logger = logging.getLogger()
|
273
473
|
ToastHandler(self).add_for_logger(root_logger)
|
274
474
|
|
475
|
+
# Flag when the logger is active.
|
476
|
+
self.active = False
|
477
|
+
|
275
478
|
def add_logger(self, *logger: LoggerImpl) -> None:
|
276
479
|
"""Add the logger, so that it gets called when `write` is called.
|
277
480
|
|
@@ -286,19 +489,12 @@ class Logger:
|
|
286
489
|
scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
|
287
490
|
strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
|
288
491
|
images={k: {kk: LogImage(v()) for kk, v in v.items()} for k, v in self.images.items()},
|
289
|
-
audios={k: {kk: LogAudio(*v()) for kk, v in v.items()} for k, v in self.audio.items()},
|
290
|
-
videos={k: {kk: LogVideo(v()) for kk, v in v.items()} for k, v in self.videos.items()},
|
291
|
-
point_cloud={k: {kk: LogPointCloud(*v()) for kk, v in v.items()} for k, v in self.point_clouds.items()},
|
292
492
|
)
|
293
493
|
|
294
494
|
def clear(self) -> None:
    """Drops every pending scalar, string and image value."""
    for pending in (self.scalars, self.strings, self.images):
        pending.clear()
|
302
498
|
|
303
499
|
def write(self, state: State) -> None:
|
304
500
|
"""Writes the current step's logging information.
|
@@ -356,6 +552,8 @@ class Logger:
|
|
356
552
|
value: The scalar value being logged
|
357
553
|
namespace: An optional logging namespace
|
358
554
|
"""
|
555
|
+
if not self.active:
|
556
|
+
raise RuntimeError("The logger is not active")
|
359
557
|
namespace = self.resolve_namespace(namespace)
|
360
558
|
|
361
559
|
@functools.lru_cache(maxsize=None)
|
@@ -372,6 +570,8 @@ class Logger:
|
|
372
570
|
value: The string value being logged
|
373
571
|
namespace: An optional logging namespace
|
374
572
|
"""
|
573
|
+
if not self.active:
|
574
|
+
raise RuntimeError("The logger is not active")
|
375
575
|
namespace = self.resolve_namespace(namespace)
|
376
576
|
|
377
577
|
@functools.lru_cache(maxsize=None)
|
@@ -383,71 +583,87 @@ class Logger:
|
|
383
583
|
def log_image(
    self,
    key: str,
    value: Callable[[], np.ndarray | Array | PILImage] | np.ndarray | Array | PILImage,
    *,
    namespace: str | None = None,
    target_resolution: tuple[int, int] | None = (512, 512),
) -> None:
    """Registers an image to be logged.

    Conversion to a PIL image is deferred into a memoized zero-argument
    future, stored under `self.images[namespace][key]`.

    Args:
        key: The key being logged
        value: The image being logged, or a zero-argument callable returning it
        namespace: An optional logging namespace
        target_resolution: The target resolution for each image; if None,
            don't resample the images

    Raises:
        RuntimeError: If the logger is not active
    """
    if not self.active:
        raise RuntimeError("The logger is not active")
    resolved = self.resolve_namespace(namespace)

    @functools.lru_cache(maxsize=None)
    def image_future() -> PILImage:
        raw = value() if callable(value) else value
        return get_image(raw, target_resolution)

    self.images[resolved][key] = image_future
|
409
609
|
|
410
610
|
def log_labeled_image(
    self,
    key: str,
    value: Callable[[], tuple[np.ndarray | Array | PILImage, str]] | tuple[np.ndarray | Array | PILImage, str],
    *,
    namespace: str | None = None,
    max_line_length: int | None = None,
    max_num_lines: int | None = None,
    target_resolution: tuple[int, int] | None = (512, 512),
    line_spacing: int = 2,
    centered: bool = True,
) -> None:
    """Registers an image with a text caption to be logged.

    The (image, label) pair is rendered lazily by a memoized future stored
    under `self.images[namespace][key]`.

    Args:
        key: The key being logged
        value: An (image, label) pair, or a zero-argument callable returning one
        namespace: An optional logging namespace
        max_line_length: The maximum line length for the label
        max_num_lines: The number of lines of spacing to add to the bottom
            of the image
        target_resolution: The target resolution for each image; if None,
            don't resample the images
        line_spacing: The spacing between adjacent lines
        centered: If set, center the text labels, otherwise align to the left

    Raises:
        RuntimeError: If the logger is not active
    """
    if not self.active:
        raise RuntimeError("The logger is not active")
    resolved = self.resolve_namespace(namespace)

    @functools.lru_cache(maxsize=None)
    def image_future() -> PILImage:
        raw_image, label = value() if callable(value) else value
        converted = get_image(raw_image, target_resolution)
        label_lines = standardize_text(label, max_line_length)
        return image_with_text(
            converted,
            label_lines,
            max_num_lines=max_num_lines,
            line_spacing=line_spacing,
            centered=centered,
        )

    self.images[resolved][key] = image_future
|
442
653
|
|
443
654
|
def log_images(
    self,
    key: str,
    value: (
        Callable[[], Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array]
        | Sequence[np.ndarray | Array | PILImage]
        | np.ndarray
        | Array
    ),
    *,
    namespace: str | None = None,
    max_images: int | None = None,
    target_resolution: tuple[int, int] | None = (256, 256),
    sep: int = 0,
) -> None:
    """Registers a batch of images to be logged as one tiled image.

    The batch is converted and tiled lazily by a memoized future stored under
    `self.images[namespace][key]`.

    Args:
        key: The key being logged
        value: The images being logged, or a zero-argument callable returning them
        namespace: An optional logging namespace
        max_images: The maximum number of images to show; extra images
            are clipped
        target_resolution: The target resolution for each image; if None,
            don't resample the images
        sep: An optional separation amount between adjacent images

    Raises:
        RuntimeError: If the logger is not active
    """
    if not self.active:
        raise RuntimeError("The logger is not active")
    resolved = self.resolve_namespace(namespace)

    @functools.lru_cache(maxsize=None)
    def images_future() -> PILImage:
        batch = value() if callable(value) else value
        if max_images is not None:
            batch = batch[:max_images]
        if isinstance(batch, Array):
            batch = as_numpy(batch)
        if isinstance(batch, Sequence):
            batch = list(batch)
        converted = [get_image(item, target_resolution) for item in batch]
        return tile_images(converted, sep)

    self.images[resolved][key] = images_future
|
476
700
|
|
477
701
|
def log_labeled_images(
|
478
702
|
self,
|
479
703
|
key: str,
|
480
|
-
value:
|
704
|
+
value: (
|
705
|
+
Callable[[], tuple[Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array, Sequence[str]]]
|
706
|
+
| tuple[Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array, Sequence[str]]
|
707
|
+
),
|
481
708
|
*,
|
482
709
|
namespace: str | None = None,
|
483
|
-
max_line_length: int | None = None,
|
484
|
-
keep_resolution: bool = False,
|
485
710
|
max_images: int | None = None,
|
486
|
-
|
711
|
+
max_line_length: int | None = None,
|
712
|
+
max_num_lines: int | None = None,
|
713
|
+
target_resolution: tuple[int, int] | None = (256, 256),
|
714
|
+
line_spacing: int = 2,
|
487
715
|
centered: bool = True,
|
716
|
+
sep: int = 0,
|
488
717
|
) -> None:
|
489
718
|
"""Logs a set of images with labels.
|
490
719
|
|
@@ -492,339 +721,43 @@ class Logger:
|
|
492
721
|
|
493
722
|
Args:
|
494
723
|
key: The key being logged
|
495
|
-
value: The images and labels being logged
|
496
|
-
(B, C, H, W), (B, H, W, C) or (B, H, W) as an RGB (3 channel)
|
497
|
-
or grayscale (1 channel) image, with exactly B labels
|
724
|
+
value: The images and labels being logged
|
498
725
|
namespace: An optional logging namespace
|
499
|
-
max_line_length: Labels longer than this length are wrapped around
|
500
|
-
keep_resolution: If set, keep the image resolution the same,
|
501
|
-
otherwise upscale or downscale the image to a standard
|
502
|
-
resolution
|
503
726
|
max_images: The maximum number of images to show; extra images
|
504
727
|
are clipped
|
728
|
+
max_line_length: The maximum line length for the label
|
729
|
+
max_num_lines: The number of lines of spacing to add to the bottom
|
730
|
+
of the image
|
731
|
+
target_resolution: The target resolution for each image; if None,
|
732
|
+
don't resample the images
|
733
|
+
line_spacing: The spacing between adjacent lines
|
734
|
+
centered: If set, center the text labels, otherwise align to the left
|
505
735
|
sep: An optional separation amount between adjacent images
|
506
|
-
centered: If set, center the text labels, otherwise align to the
|
507
|
-
left
|
508
|
-
"""
|
509
|
-
namespace = self.resolve_namespace(namespace)
|
510
|
-
|
511
|
-
@functools.lru_cache(maxsize=None)
|
512
|
-
def labeled_images_future() -> Array:
|
513
|
-
raise NotImplementedError
|
514
|
-
|
515
|
-
self.images[namespace][key] = labeled_images_future
|
516
|
-
|
517
|
-
def log_audio(
|
518
|
-
self,
|
519
|
-
key: str,
|
520
|
-
value: Callable[[], Array] | Array,
|
521
|
-
*,
|
522
|
-
namespace: str | None = None,
|
523
|
-
sample_rate: int = 44100,
|
524
|
-
log_spec: bool = False,
|
525
|
-
n_fft_ms: float = 32.0,
|
526
|
-
hop_length_ms: float | None = None,
|
527
|
-
channel_select_mode: ChannelSelectMode = "first",
|
528
|
-
keep_resolution: bool = False,
|
529
|
-
) -> None:
|
530
|
-
"""Logs an audio clip.
|
531
|
-
|
532
|
-
Args:
|
533
|
-
key: The key being logged
|
534
|
-
value: The audio clip being logged; can be (C, T) or (T) as
|
535
|
-
a mono (1 channel) or stereo (2 channel) audio clip
|
536
|
-
namespace: An optional logging namespace
|
537
|
-
sample_rate: The sample rate of the audio clip
|
538
|
-
log_spec: If set, also log the spectrogram
|
539
|
-
n_fft_ms: FFT size, in milliseconds
|
540
|
-
hop_length_ms: The FFT hop length, in milliseconds
|
541
|
-
channel_select_mode: How to select the channel if the audio is
|
542
|
-
stereo; can be "first", "last", or "mean"; this is only used
|
543
|
-
for the spectrogram
|
544
|
-
keep_resolution: If set, keep the resolution of the
|
545
|
-
spectrogram; otherwise, make human-viewable
|
546
|
-
"""
|
547
|
-
namespace = self.resolve_namespace(namespace)
|
548
|
-
|
549
|
-
@functools.lru_cache(maxsize=None)
|
550
|
-
def raw_audio_future() -> Array:
|
551
|
-
raise NotImplementedError
|
552
|
-
|
553
|
-
@functools.lru_cache(maxsize=None)
|
554
|
-
def audio_future() -> tuple[Array, int]:
|
555
|
-
raise NotImplementedError
|
556
|
-
|
557
|
-
self.audio[namespace][key] = audio_future
|
558
|
-
|
559
|
-
if log_spec:
|
560
|
-
# Using a unique key for the spectrogram is very important because
|
561
|
-
# otherwise Tensorboard will have some issues.
|
562
|
-
self.log_spectrogram(
|
563
|
-
key=f"{key}_spec",
|
564
|
-
value=raw_audio_future,
|
565
|
-
namespace=namespace,
|
566
|
-
sample_rate=sample_rate,
|
567
|
-
n_fft_ms=n_fft_ms,
|
568
|
-
hop_length_ms=hop_length_ms,
|
569
|
-
channel_select_mode=channel_select_mode,
|
570
|
-
keep_resolution=keep_resolution,
|
571
|
-
)
|
572
|
-
|
573
|
-
def log_audios(
|
574
|
-
self,
|
575
|
-
key: str,
|
576
|
-
value: Callable[[], Array] | Array,
|
577
|
-
*,
|
578
|
-
namespace: str | None = None,
|
579
|
-
sep_ms: float = 0.0,
|
580
|
-
max_audios: int | None = None,
|
581
|
-
sample_rate: int = 44100,
|
582
|
-
log_spec: bool = False,
|
583
|
-
n_fft_ms: float = 32.0,
|
584
|
-
hop_length_ms: float | None = None,
|
585
|
-
channel_select_mode: ChannelSelectMode = "first",
|
586
|
-
spec_sep: int = 0,
|
587
|
-
keep_resolution: bool = False,
|
588
|
-
) -> None:
|
589
|
-
"""Logs multiple audio clips.
|
590
|
-
|
591
|
-
Args:
|
592
|
-
key: The key being logged
|
593
|
-
value: The audio clip being logged; can be (B, C, T) or (B, T) as
|
594
|
-
a mono (1 channel) or stereo (2 channel) audio clip, with
|
595
|
-
exactly B clips
|
596
|
-
namespace: An optional logging namespace
|
597
|
-
sep_ms: An optional separation amount between adjacent audio clips
|
598
|
-
max_audios: An optional maximum number of audio clips to log
|
599
|
-
sample_rate: The sample rate of the audio clip
|
600
|
-
log_spec: If set, also log the spectrogram
|
601
|
-
n_fft_ms: FFT size, in milliseconds
|
602
|
-
hop_length_ms: The FFT hop length, in milliseconds
|
603
|
-
channel_select_mode: How to select the channel if the audio is
|
604
|
-
stereo; can be "first", "last", or "mean"; this is only used
|
605
|
-
for the spectrogram
|
606
|
-
spec_sep: An optional separation amount between adjacent
|
607
|
-
spectrograms
|
608
|
-
keep_resolution: If set, keep the resolution of the
|
609
|
-
spectrogram; otherwise, make human-viewable
|
610
|
-
"""
|
611
|
-
namespace = self.resolve_namespace(namespace)
|
612
|
-
|
613
|
-
@functools.lru_cache(maxsize=None)
|
614
|
-
def raw_audio_future() -> Array:
|
615
|
-
raise NotImplementedError
|
616
|
-
|
617
|
-
@functools.lru_cache(maxsize=None)
|
618
|
-
def audio_future() -> tuple[Array, int]:
|
619
|
-
raise NotImplementedError
|
620
|
-
|
621
|
-
self.audio[namespace][key] = audio_future
|
622
|
-
|
623
|
-
if log_spec:
|
624
|
-
# Using a unique key for the spectrogram is very important because
|
625
|
-
# otherwise Tensorboard will have some issues.
|
626
|
-
self.log_spectrograms(
|
627
|
-
key=f"{key}_spec",
|
628
|
-
value=raw_audio_future,
|
629
|
-
namespace=namespace,
|
630
|
-
max_audios=max_audios,
|
631
|
-
sample_rate=sample_rate,
|
632
|
-
n_fft_ms=n_fft_ms,
|
633
|
-
hop_length_ms=hop_length_ms,
|
634
|
-
channel_select_mode=channel_select_mode,
|
635
|
-
spec_sep=spec_sep,
|
636
|
-
keep_resolution=keep_resolution,
|
637
|
-
)
|
638
|
-
|
639
|
-
def log_spectrogram(
|
640
|
-
self,
|
641
|
-
key: str,
|
642
|
-
value: Callable[[], Array] | Array,
|
643
|
-
*,
|
644
|
-
namespace: str | None = None,
|
645
|
-
sample_rate: int = 44100,
|
646
|
-
n_fft_ms: float = 32.0,
|
647
|
-
hop_length_ms: float | None = None,
|
648
|
-
channel_select_mode: ChannelSelectMode = "first",
|
649
|
-
keep_resolution: bool = False,
|
650
|
-
) -> None:
|
651
|
-
"""Logs spectrograms of an audio clip.
|
652
|
-
|
653
|
-
Args:
|
654
|
-
key: The key being logged
|
655
|
-
value: The audio clip being logged; can be (C, T) or (T) as
|
656
|
-
a mono (1 channel) or stereo (2 channel) audio clip
|
657
|
-
namespace: An optional logging namespace
|
658
|
-
sample_rate: The sample rate of the audio clip
|
659
|
-
n_fft_ms: FFT size, in milliseconds
|
660
|
-
hop_length_ms: The FFT hop length, in milliseconds
|
661
|
-
channel_select_mode: How to select the channel if the audio is
|
662
|
-
stereo; can be "first", "last", or "mean"; this is only used
|
663
|
-
for the spectrogram
|
664
|
-
keep_resolution: If set, keep the resolution of the
|
665
|
-
spectrogram; otherwise, make human-viewable
|
666
|
-
"""
|
667
|
-
namespace = self.resolve_namespace(namespace)
|
668
|
-
|
669
|
-
@functools.lru_cache(maxsize=None)
|
670
|
-
def spec_future() -> Array:
|
671
|
-
raise NotImplementedError
|
672
|
-
|
673
|
-
self.images[namespace][key] = spec_future
|
674
|
-
|
675
|
-
def log_spectrograms(
|
676
|
-
self,
|
677
|
-
key: str,
|
678
|
-
value: Callable[[], Array] | Array,
|
679
|
-
*,
|
680
|
-
namespace: str | None = None,
|
681
|
-
max_audios: int | None = None,
|
682
|
-
sample_rate: int = 44100,
|
683
|
-
n_fft_ms: float = 32.0,
|
684
|
-
hop_length_ms: float | None = None,
|
685
|
-
channel_select_mode: ChannelSelectMode = "first",
|
686
|
-
spec_sep: int = 0,
|
687
|
-
keep_resolution: bool = False,
|
688
|
-
) -> None:
|
689
|
-
"""Logs spectrograms of audio clips.
|
690
|
-
|
691
|
-
Args:
|
692
|
-
key: The key being logged
|
693
|
-
value: The audio clip being logged; can be (B, C, T) or (B, T) as
|
694
|
-
a mono (1 channel) or stereo (2 channel) audio clip, with
|
695
|
-
exactly B clips
|
696
|
-
namespace: An optional logging namespace
|
697
|
-
max_audios: An optional maximum number of audio clips to log
|
698
|
-
sample_rate: The sample rate of the audio clip
|
699
|
-
n_fft_ms: FFT size, in milliseconds
|
700
|
-
hop_length_ms: The FFT hop length, in milliseconds
|
701
|
-
channel_select_mode: How to select the channel if the audio is
|
702
|
-
stereo; can be "first", "last", or "mean"; this is only used
|
703
|
-
for the spectrogram
|
704
|
-
spec_sep: An optional separation amount between adjacent
|
705
|
-
spectrograms
|
706
|
-
keep_resolution: If set, keep the resolution of the
|
707
|
-
spectrogram; otherwise, make human-viewable
|
708
|
-
"""
|
709
|
-
namespace = self.resolve_namespace(namespace)
|
710
|
-
|
711
|
-
@functools.lru_cache(maxsize=None)
|
712
|
-
def spec_future() -> Array:
|
713
|
-
raise NotImplementedError
|
714
|
-
|
715
|
-
self.images[namespace][key] = spec_future
|
716
|
-
|
717
|
-
def log_video(
|
718
|
-
self,
|
719
|
-
key: str,
|
720
|
-
value: Callable[[], Array] | Array,
|
721
|
-
*,
|
722
|
-
namespace: str | None = None,
|
723
|
-
fps: int | None = None,
|
724
|
-
length: float | None = None,
|
725
|
-
) -> None:
|
726
|
-
"""Logs a video.
|
727
|
-
|
728
|
-
Args:
|
729
|
-
key: The key being logged
|
730
|
-
value: The video being logged; the video can be (T, C, H, W),
|
731
|
-
(T, H, W, C) or (T, H, W) as an RGB (3 channel) or grayscale
|
732
|
-
(1 channel) video
|
733
|
-
namespace: An optional logging namespace
|
734
|
-
fps: The video frames per second
|
735
|
-
length: The desired video length, in seconds, at the target FPS
|
736
|
-
"""
|
737
|
-
namespace = self.resolve_namespace(namespace)
|
738
|
-
|
739
|
-
@functools.lru_cache(maxsize=None)
|
740
|
-
def video_future() -> Array:
|
741
|
-
raise NotImplementedError
|
742
|
-
|
743
|
-
self.videos[namespace][key] = video_future
|
744
|
-
|
745
|
-
def log_videos(
|
746
|
-
self,
|
747
|
-
key: str,
|
748
|
-
value: Callable[[], Array | list[Array]] | Array | list[Array],
|
749
|
-
*,
|
750
|
-
namespace: str | None = None,
|
751
|
-
max_videos: int | None = None,
|
752
|
-
sep: int = 0,
|
753
|
-
fps: int | None = None,
|
754
|
-
length: int | None = None,
|
755
|
-
) -> None:
|
756
|
-
"""Logs a set of video.
|
757
|
-
|
758
|
-
Args:
|
759
|
-
key: The key being logged
|
760
|
-
value: The videos being logged; the video can be (B, T, C, H, W),
|
761
|
-
(B, T, H, W, C) or (B T, H, W) as an RGB (3 channel) or
|
762
|
-
grayscale (1 channel) video
|
763
|
-
namespace: An optional logging namespace
|
764
|
-
max_videos: The maximum number of videos to show; extra images
|
765
|
-
are clipped
|
766
|
-
sep: An optional separation amount between adjacent videos
|
767
|
-
fps: The video frames per second
|
768
|
-
length: The desired video length, in seconds, at the target FPS
|
769
736
|
"""
|
737
|
+
if not self.active:
|
738
|
+
raise RuntimeError("The logger is not active")
|
770
739
|
namespace = self.resolve_namespace(namespace)
|
771
740
|
|
772
741
|
@functools.lru_cache(maxsize=None)
|
773
|
-
def
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
namespace: An optional logging namespace
|
791
|
-
"""
|
792
|
-
namespace = self.resolve_namespace(namespace)
|
742
|
+
def images_future() -> PILImage:
|
743
|
+
images, labels = value() if callable(value) else value
|
744
|
+
if max_images is not None:
|
745
|
+
images = images[:max_images]
|
746
|
+
labels = labels[:max_images]
|
747
|
+
images = [get_image(image, target_resolution) for image in images]
|
748
|
+
images = [
|
749
|
+
image_with_text(
|
750
|
+
image,
|
751
|
+
standardize_text(label, max_line_length),
|
752
|
+
max_num_lines=max_num_lines,
|
753
|
+
line_spacing=line_spacing,
|
754
|
+
centered=centered,
|
755
|
+
)
|
756
|
+
for image, label in zip(images, labels)
|
757
|
+
]
|
758
|
+
return tile_images(images, sep)
|
793
759
|
|
794
|
-
|
795
|
-
def histogram_future() -> Array:
|
796
|
-
raise NotImplementedError
|
797
|
-
|
798
|
-
self.histograms[namespace][key] = histogram_future
|
799
|
-
|
800
|
-
def log_point_cloud(
|
801
|
-
self,
|
802
|
-
key: str,
|
803
|
-
value: Callable[[], Array] | Array,
|
804
|
-
*,
|
805
|
-
namespace: str | None = None,
|
806
|
-
max_points: int = 1000,
|
807
|
-
colors: Callable[[], Array] | Array | None = None,
|
808
|
-
) -> None:
|
809
|
-
"""Logs a point cloud.
|
810
|
-
|
811
|
-
Args:
|
812
|
-
key: The key being logged
|
813
|
-
value: The point cloud values, with shape (N, 3) or (B, ..., 3);
|
814
|
-
can pass multiple batches in order to show multiple point
|
815
|
-
clouds
|
816
|
-
namespace: An optional logging namespace
|
817
|
-
max_points: An optional maximum number of points in the point cloud
|
818
|
-
colors: An optional color for each point, with the same shape as
|
819
|
-
the point cloud
|
820
|
-
"""
|
821
|
-
namespace = self.resolve_namespace(namespace)
|
822
|
-
|
823
|
-
@functools.lru_cache(maxsize=None)
|
824
|
-
def point_cloud_future() -> tuple[Array, Array | None]:
|
825
|
-
raise NotImplementedError
|
826
|
-
|
827
|
-
self.point_clouds[namespace][key] = point_cloud_future
|
760
|
+
self.images[namespace][key] = images_future
|
828
761
|
|
829
762
|
def log_git_state(self, git_state: str) -> None:
|
830
763
|
for logger in self.loggers:
|
@@ -839,6 +772,7 @@ class Logger:
|
|
839
772
|
logger.log_config(config)
|
840
773
|
|
841
774
|
def __enter__(self) -> Self:
|
775
|
+
self.active = True
|
842
776
|
for logger in self.loggers:
|
843
777
|
logger.start()
|
844
778
|
return self
|
@@ -846,3 +780,4 @@ class Logger:
|
|
846
780
|
def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
|
847
781
|
for logger in self.loggers:
|
848
782
|
logger.stop()
|
783
|
+
self.active = False
|