xax 0.0.3__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +122 -8
- xax/core/conf.py +9 -33
- xax/core/state.py +13 -23
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +8 -4
- xax/requirements-dev.txt +9 -1
- xax/requirements.txt +17 -10
- xax/task/base.py +2 -6
- xax/task/logger.py +419 -412
- xax/task/loggers/callback.py +44 -0
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +16 -33
- xax/task/mixins/__init__.py +3 -1
- xax/task/mixins/artifacts.py +19 -9
- xax/task/mixins/checkpointing.py +221 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +26 -15
- xax/task/mixins/data_loader.py +27 -19
- xax/task/mixins/gpu_stats.py +22 -8
- xax/task/mixins/logger.py +5 -251
- xax/task/mixins/process.py +8 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +236 -145
- xax/task/script.py +1 -1
- xax/task/task.py +13 -5
- xax/utils/data/collate.py +6 -6
- xax/utils/experiments.py +45 -1
- xax/utils/logging.py +29 -0
- xax/utils/tensorboard.py +89 -21
- xax-0.0.6.dist-info/METADATA +50 -0
- xax-0.0.6.dist-info/RECORD +52 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.3.dist-info/METADATA +0 -39
- xax-0.0.3.dist-info/RECORD +0 -49
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/LICENSE +0 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/top_level.txt +0 -0
xax/task/logger.py
CHANGED
@@ -11,16 +11,21 @@ captions to images, and so on.
|
|
11
11
|
|
12
12
|
import functools
|
13
13
|
import logging
|
14
|
+
import math
|
15
|
+
import re
|
14
16
|
import time
|
15
17
|
from abc import ABC, abstractmethod
|
16
18
|
from collections import defaultdict
|
17
19
|
from dataclasses import dataclass
|
18
20
|
from types import TracebackType
|
19
|
-
from typing import Callable, Literal, Self, Sequence, TypeVar, get_args
|
21
|
+
from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, get_args
|
20
22
|
|
23
|
+
import jax
|
24
|
+
import jax.numpy as jnp
|
21
25
|
import numpy as np
|
22
26
|
from jaxtyping import Array
|
23
|
-
from
|
27
|
+
from PIL import Image, ImageDraw, ImageFont
|
28
|
+
from PIL.Image import Image as PILImage
|
24
29
|
|
25
30
|
from xax.core.state import Phase, State
|
26
31
|
from xax.utils.experiments import IntervalTicker
|
@@ -34,15 +39,58 @@ Number = int | float | Array | np.ndarray
|
|
34
39
|
|
35
40
|
ChannelSelectMode = Literal["first", "last", "mean"]
|
36
41
|
|
37
|
-
VALID_VIDEO_CHANNEL_COUNTS = {1, 3}
|
38
|
-
VALID_AUDIO_CHANNEL_COUNTS = {1, 2}
|
39
|
-
TARGET_FPS = 12
|
40
42
|
DEFAULT_NAMESPACE = "value"
|
41
43
|
|
42
|
-
|
43
44
|
NAMESPACE_STACK: list[str] = []
|
44
45
|
|
45
46
|
|
47
|
+
def standardize_text(text: str, max_line_length: int | None = None, remove_non_ascii: bool = False) -> list[str]:
|
48
|
+
"""Standardizes a text string to a list of lines.
|
49
|
+
|
50
|
+
Args:
|
51
|
+
text: The text to standardize
|
52
|
+
max_line_length: If set, truncate lines to this length
|
53
|
+
remove_non_ascii: Remove non-ASCII characters if present
|
54
|
+
|
55
|
+
Returns:
|
56
|
+
The standardized text lines
|
57
|
+
"""
|
58
|
+
|
59
|
+
def _chunk_lines(text: str, max_length: int) -> Iterator[str]:
|
60
|
+
for i in range(0, len(text), max_length):
|
61
|
+
yield text[i : i + max_length]
|
62
|
+
|
63
|
+
if remove_non_ascii:
|
64
|
+
text = "".join(char for char in text if ord(char) < 128)
|
65
|
+
lines = [re.sub(r"\s+", " ", line) for line in re.split(r"[\n\r]+", text.strip())]
|
66
|
+
if max_line_length is not None:
|
67
|
+
lines = [subline for line in lines for subline in _chunk_lines(line, max_line_length)]
|
68
|
+
return lines
|
69
|
+
|
70
|
+
|
71
|
+
def make_human_viewable_resolution(
    image: PILImage,
    interpolation: Image.Resampling = Image.Resampling.LANCZOS,
    trg_res: tuple[int, int] = (512, 512),
) -> PILImage:
    """Resizes image to human-viewable resolution.

    The aspect ratio is kept; only the area is matched to the target.

    Args:
        image: The PIL image to resize
        interpolation: Interpolation mode to use for image resizing
        trg_res: The target image resolution, as (height, width); the image
            will be reshaped to have approximately the same area as an image
            with this resolution

    Returns:
        The resized image. An image with zero area is returned unchanged,
        since there is nothing to scale.
    """
    width, height = image.size
    if width <= 0 or height <= 0:
        # Avoids dividing by a zero area below.
        return image

    trg_height, trg_width = trg_res
    factor = math.sqrt((trg_height * trg_width) / (height * width))

    # Truncation can drive the short side of a very elongated image to zero,
    # so keep each side at least one pixel.
    new_height = max(1, int(height * factor))
    new_width = max(1, int(width * factor))
    return image.resize((new_width, new_height), interpolation)
|
92
|
+
|
93
|
+
|
46
94
|
class namespace_context: # noqa: N801
|
47
95
|
def __init__(self, name: str | None) -> None:
|
48
96
|
self._name = name
|
@@ -62,45 +110,134 @@ class namespace_context: # noqa: N801
|
|
62
110
|
NAMESPACE_STACK.pop()
|
63
111
|
|
64
112
|
|
65
|
-
|
66
|
-
|
67
|
-
|
113
|
+
def normalize(x: np.ndarray) -> np.ndarray:
    """Rescales an array linearly so its minimum maps to 0 and its maximum to 1.

    Args:
        x: The array to rescale

    Returns:
        The rescaled array. A constant array maps to all zeros instead of
        NaN, and an empty array is returned empty.
    """
    if x.size == 0:
        # `min` / `max` are undefined on empty input; the division only
        # promotes the dtype the same way the non-empty path does.
        return np.true_divide(x, 1)

    lo = x.min()
    span = x.max() - lo
    # A flat array has zero range; dividing by one instead leaves the shifted
    # values (all zero) untouched rather than producing NaN.
    return (x - lo) / (span if span != 0 else 1)
|
115
|
+
|
116
|
+
|
117
|
+
def ternary_search_optimal_side_counts(height: int, width: int, count: int) -> tuple[int, int]:
    """Picks how many tiles to stack vertically and horizontally.

    Every way of writing ``count`` as ``rows * cols`` is considered, and the
    one whose tiled canvas has the largest shorter side is chosen, i.e. the
    layout closest to square.

    Args:
        height: The height of a single tile
        width: The width of a single tile
        count: The number of tiles; must be at least one

    Returns:
        The (rows, cols) pair, with ``rows * cols == count``

    Raises:
        ValueError: If ``count`` is less than one
    """
    if count < 1:
        raise ValueError(f"Expected a positive tile count, got {count}")

    # Enumerate divisors only up to floor(sqrt(count)). Rounding the bound up
    # would list some pairs twice, and the search below relies on the row
    # count increasing strictly along the list.
    small_divisors = [i for i in range(1, math.isqrt(count) + 1) if count % i == 0]
    factors = [(i, count // i) for i in small_divisors]
    factors += [(count // i, i) for i in reversed(small_divisors) if i * i != count]

    def penalty(i: int) -> float:
        rows, cols = factors[i]
        return -(min(rows * height, cols * width) ** 2)

    lo, hi = 0, len(factors) - 1

    # Runs ternary search to minimize penalty.
    while lo < hi - 2:
        lmid, rmid = (lo * 2 + hi) // 3, (lo + hi * 2) // 3
        if penalty(lmid) > penalty(rmid):
            lo = lmid
        else:
            hi = rmid

    # At most three candidates remain; returns the lowest-penalty one.
    mid = (lo + hi) // 2
    plo, pmid, phi = penalty(lo), penalty(mid), penalty(hi)

    if pmid <= plo and pmid <= phi:
        return factors[mid]
    if plo <= phi:
        return factors[lo]
    return factors[hi]
|
147
|
+
|
148
|
+
|
149
|
+
def tile_images_different_sizes(images: list[PILImage], sep: int) -> PILImage:
    """Tiles a list of images into a single image, even if they have different sizes.

    The images are pasted left to right in a single row, aligned to the top
    edge, on a new RGB canvas as tall as the tallest image.

    Args:
        images: The images to tile.
        sep: The separation between adjacent images.

    Returns:
        The tiled image, or an empty RGB image if there are no images.
    """
    if not images:
        # `max` below has nothing to work with; mirror what `tile_images`
        # returns for an empty list.
        return Image.new("RGB", (0, 0))

    total_width = sum(image.width for image in images)
    max_height = max(image.height for image in images)
    tiled = Image.new("RGB", (total_width + (len(images) - 1) * sep, max_height))

    x = 0
    for image in images:
        tiled.paste(image, (x, 0))
        x += image.width + sep
    return tiled
|
68
166
|
|
69
167
|
|
70
|
-
|
71
|
-
|
72
|
-
frames: Array
|
73
|
-
sample_rate: int
|
168
|
+
def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
    """Tiles a list of images into a single image.

    Equally-sized images are arranged on a grid whose row and column counts
    come from `ternary_search_optimal_side_counts`; images of differing sizes
    are delegated to `tile_images_different_sizes`.

    Args:
        images: The images to tile.
        sep: The separation between adjacent images.

    Returns:
        The tiled image.
    """
    if not images:
        return Image.new("RGB", (0, 0))

    # Mixed sizes cannot go on a regular grid.
    reference_size = images[0].size
    if any(image.size != reference_size for image in images):
        return tile_images_different_sizes(images, sep)

    width, height = reference_size
    rows, cols = ternary_search_optimal_side_counts(height, width, len(images))

    canvas_width = cols * width + (cols - 1) * sep
    canvas_height = rows * height + (rows - 1) * sep
    canvas = Image.new("RGB", (canvas_width, canvas_height))

    # Fills the grid row by row.
    for index, image in enumerate(images):
        row, col = divmod(index, cols)
        canvas.paste(image, (col * (width + sep), row * (height + sep)))

    return canvas
|
195
|
+
|
196
|
+
|
197
|
+
def as_numpy(array: Array) -> np.ndarray:
    """Fetches an array to the host and returns it as a numpy array.

    Floating-point data is cast to float32, integer data to int32 and boolean
    data to bool; any other dtype is passed through unchanged.

    Args:
        array: The array to convert

    Returns:
        A numpy copy of the array with the standardized dtype
    """
    host_array = jax.device_get(array)

    # First matching dtype family wins, then the search stops.
    casts = (
        (jnp.floating, jnp.float32),
        (jnp.integer, jnp.int32),
        (jnp.bool_, jnp.bool_),
    )
    for family, target in casts:
        if jax.dtypes.issubdtype(host_array.dtype, family):
            host_array = host_array.astype(target)
            break

    return np.array(host_array)
|
206
|
+
|
207
|
+
|
208
|
+
@dataclass(kw_only=True)
class LogImage:
    """Container for a single image to be logged."""

    # The image payload, held as a PIL image.
    image: PILImage
|
211
|
+
|
212
|
+
|
213
|
+
@dataclass(kw_only=True)
|
77
214
|
class LogVideo:
|
78
|
-
|
215
|
+
"""Container for video data and metadata.
|
79
216
|
|
217
|
+
Attributes:
|
218
|
+
frames: Video frames as a numpy array of shape (T,H,W,C)
|
219
|
+
fps: Frames per second
|
220
|
+
"""
|
80
221
|
|
81
|
-
|
82
|
-
|
83
|
-
xyz: Array
|
84
|
-
colors: Array | None
|
222
|
+
frames: np.ndarray
|
223
|
+
fps: int
|
85
224
|
|
86
225
|
|
87
|
-
@dataclass
|
226
|
+
@dataclass(kw_only=True)
|
88
227
|
class LogLine:
|
89
228
|
state: State
|
90
229
|
scalars: dict[str, dict[str, Number]]
|
91
230
|
strings: dict[str, dict[str, str]]
|
92
231
|
images: dict[str, dict[str, LogImage]]
|
93
|
-
audios: dict[str, dict[str, LogAudio]]
|
94
232
|
videos: dict[str, dict[str, LogVideo]]
|
95
|
-
point_cloud: dict[str, dict[str, LogPointCloud]]
|
96
233
|
|
97
234
|
|
98
|
-
@dataclass
|
235
|
+
@dataclass(kw_only=True)
|
99
236
|
class LogErrorSummary:
|
100
237
|
message: str
|
101
238
|
|
102
239
|
|
103
|
-
@dataclass
|
240
|
+
@dataclass(kw_only=True)
|
104
241
|
class LogError:
|
105
242
|
message: str
|
106
243
|
location: str | None = None
|
@@ -113,7 +250,7 @@ class LogError:
|
|
113
250
|
return message
|
114
251
|
|
115
252
|
|
116
|
-
@dataclass
|
253
|
+
@dataclass(kw_only=True)
|
117
254
|
class LogStatus:
|
118
255
|
message: str
|
119
256
|
created: float
|
@@ -121,7 +258,7 @@ class LogStatus:
|
|
121
258
|
lineno: int | None = None
|
122
259
|
|
123
260
|
|
124
|
-
@dataclass
|
261
|
+
@dataclass(kw_only=True)
|
125
262
|
class LogPing:
|
126
263
|
message: str
|
127
264
|
created: float
|
@@ -129,6 +266,120 @@ class LogPing:
|
|
129
266
|
lineno: int | None = None
|
130
267
|
|
131
268
|
|
269
|
+
def get_image(image: np.ndarray | Array | PILImage, target_resolution: tuple[int, int] | None = None) -> LogImage:
    """Converts an array or PIL image into a `LogImage`.

    Arrays may have shape HW, HWC or CHW with one or three channels. Floating
    point arrays are rescaled to the full 0-255 range; uint8 arrays are used
    as they are.

    Args:
        image: The image to convert
        target_resolution: If set, resize the image to roughly the same area
            as this resolution

    Returns:
        The image wrapped in a `LogImage`

    Raises:
        ValueError: If the input type, dtype or channel layout is unsupported
        RuntimeError: If an array is not two- or three-dimensional
    """
    if not isinstance(image, (np.ndarray, Array, PILImage)):
        raise ValueError(f"Unsupported image type: {type(image)}")
    if isinstance(image, Array):
        if image.dtype == np.uint8:
            # `as_numpy` widens every integer dtype to int32, which the dtype
            # check below would reject, so uint8 images are fetched as-is.
            image = np.asarray(jax.device_get(image))
        else:
            image = as_numpy(image)
    if isinstance(image, np.ndarray):
        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)
        if image.ndim != 3:
            raise RuntimeError(f"Expected image to have shape HW, HWC, or CHW, got {image.shape}")

        # Normalizes the image and converts to integer.
        if np.issubdtype(image.dtype, np.floating):
            image = (normalize(image) * 255).round().astype(np.uint8)
        elif image.dtype == np.uint8:
            pass
        else:
            raise ValueError(f"Unsupported image dtype: {image.dtype}")

        # Converts to a PIL image. Channels-last layouts are tried first, so
        # an ambiguous shape is read as HWC.
        if image.shape[-1] == 1:
            image = Image.fromarray(image[..., 0])
        elif image.shape[-1] == 3:
            image = Image.fromarray(image)
        elif image.shape[0] == 1:
            image = Image.fromarray(image[0])
        elif image.shape[0] == 3:
            image = Image.fromarray(image.transpose(1, 2, 0))
        else:
            raise ValueError(f"Unsupported image shape: {image.shape}")

    if target_resolution is not None:
        image = make_human_viewable_resolution(image, trg_res=target_resolution)
    return LogImage(image=image)
|
303
|
+
|
304
|
+
|
305
|
+
def image_with_text(
    image: PILImage,
    text: list[str],
    max_num_lines: int | None,
    line_spacing: int,
    centered: bool,
) -> LogImage:
    """Adds a text label to an image.

    The image is pasted at the top of a taller white canvas and the label is
    drawn in black in the strip underneath.

    Args:
        image: The PIL image to label
        text: The text label for the image, one entry per line
        max_num_lines: The number of lines of spacing to add to the bottom
            of the image; extra label lines are dropped. If None, room is
            made for every line.
        line_spacing: The spacing between adjacent lines
        centered: If set, center the text labels, otherwise align to the left

    Returns:
        The image with a text label, or the unchanged image if there is no
        text left to draw
    """
    if max_num_lines is not None:
        text = text[:max_num_lines]
    if not text:
        # Also covers a label truncated to nothing, where there is no first
        # line to measure.
        return LogImage(image=image)
    if max_num_lines is None:
        max_num_lines = len(text)

    width, height = image.size
    font: ImageFont.ImageFont = ImageFont.load_default()
    _, _, _, line_height = font.getbbox(text[0])
    new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)

    # Named colors are resolved per image mode. A bare integer 255 is only
    # white for single-band images; on an RGB canvas it is pure red.
    padded_image = Image.new(image.mode, (new_width, new_height), "white")
    padded_image.paste(image, (0, 0))
    drawer = ImageDraw.Draw(padded_image)
    for i, text_line in enumerate(text):
        text_line_top = height + line_spacing + i * (line_height + line_spacing)
        if centered:
            _, _, line_width, _ = font.getbbox(text_line)
            text_line_left = (width - line_width) / 2
            drawer.text((text_line_left, text_line_top), text_line, font=font, fill="black")
        else:
            drawer.text((line_spacing, text_line_top), text_line, font=font, fill="black")
    return LogImage(image=padded_image)
|
347
|
+
|
348
|
+
|
349
|
+
def get_video(video: np.ndarray | Array, fps: int = 30) -> LogVideo:
    """Converts video data to standard format.

    Args:
        video: The video frames. Can be:
            - A numpy array of shape (T, H, W, C) or (T, C, H, W)
            - A JAX array of shape (T, H, W, C) or (T, C, H, W)
            where C is 1 or 3. Floating point frames are rescaled to the full
            0-255 range; uint8 frames are used as they are.
        fps: Frames per second

    Returns:
        LogVideo containing standardized video frames in (T, H, W, C) order

    Raises:
        ValueError: If the input type, rank or dtype is unsupported
    """
    if isinstance(video, Array):
        if video.dtype == np.uint8:
            # `as_numpy` widens every integer dtype to int32, which the dtype
            # check below would reject, so uint8 frames are fetched as-is.
            video = np.asarray(jax.device_get(video))
        else:
            video = as_numpy(video)

    if not isinstance(video, np.ndarray):
        raise ValueError(f"Unsupported video type: {type(video)}")

    # Handle different dimension orderings
    if video.ndim != 4:
        raise ValueError(f"Expected video array of shape (T, H, W, C) or (T, C, H, W), got shape {video.shape}")

    # Channels-last is tried first, as `get_image` does, so a frame that is
    # three pixels tall is not mistaken for (T, C, H, W).
    channel_counts = (1, 3)
    if video.shape[-1] not in channel_counts and video.shape[1] in channel_counts:
        video = video.transpose(0, 2, 3, 1)

    # Normalize and convert to uint8 if needed
    if np.issubdtype(video.dtype, np.floating):
        video = (normalize(video) * 255).round().astype(np.uint8)
    elif video.dtype != np.uint8:
        raise ValueError(f"Unsupported video dtype: {video.dtype}")

    return LogVideo(frames=video, fps=fps)
|
381
|
+
|
382
|
+
|
132
383
|
class LoggerImpl(ABC):
|
133
384
|
def __init__(self, log_interval_seconds: float = 1.0) -> None:
|
134
385
|
"""Defines some default behavior for loggers.
|
@@ -187,25 +438,12 @@ class LoggerImpl(ABC):
|
|
187
438
|
ping: The ping to write.
|
188
439
|
"""
|
189
440
|
|
190
|
-
def
|
191
|
-
"""Logs
|
192
|
-
|
193
|
-
Args:
|
194
|
-
git_state: The Git state, as text blocks.
|
195
|
-
"""
|
196
|
-
|
197
|
-
def log_training_code(self, training_code: str) -> None:
|
198
|
-
"""Logs the training script code.
|
441
|
+
def log_file(self, name: str, contents: str) -> None:
|
442
|
+
"""Logs a large text file.
|
199
443
|
|
200
444
|
Args:
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
def log_config(self, config: DictConfig) -> None:
|
205
|
-
"""Logs the configuration for the current run.
|
206
|
-
|
207
|
-
Args:
|
208
|
-
config: The configuration, as a DictConfig.
|
445
|
+
name: The name of the file.
|
446
|
+
contents: The contents of the file.
|
209
447
|
"""
|
210
448
|
|
211
449
|
def should_log(self, state: State) -> bool:
|
@@ -260,11 +498,8 @@ class Logger:
|
|
260
498
|
def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
|
261
499
|
self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
|
262
500
|
self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
|
263
|
-
self.images: dict[str, dict[str, Callable[[],
|
264
|
-
self.
|
265
|
-
self.videos: dict[str, dict[str, Callable[[], Array]]] = defaultdict(dict)
|
266
|
-
self.histograms: dict[str, dict[str, Callable[[], Array]]] = defaultdict(dict)
|
267
|
-
self.point_clouds: dict[str, dict[str, Callable[[], tuple[Array, Array | None]]]] = defaultdict(dict)
|
501
|
+
self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
|
502
|
+
self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
|
268
503
|
self.default_namespace = default_namespace
|
269
504
|
self.loggers: list[LoggerImpl] = []
|
270
505
|
|
@@ -272,6 +507,9 @@ class Logger:
|
|
272
507
|
root_logger = logging.getLogger()
|
273
508
|
ToastHandler(self).add_for_logger(root_logger)
|
274
509
|
|
510
|
+
# Flag when the logger is active.
|
511
|
+
self.active = False
|
512
|
+
|
275
513
|
def add_logger(self, *logger: LoggerImpl) -> None:
|
276
514
|
"""Add the logger, so that it gets called when `write` is called.
|
277
515
|
|
@@ -285,20 +523,15 @@ class Logger:
|
|
285
523
|
state=state,
|
286
524
|
scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
|
287
525
|
strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
|
288
|
-
images={k: {kk:
|
289
|
-
|
290
|
-
videos={k: {kk: LogVideo(v()) for kk, v in v.items()} for k, v in self.videos.items()},
|
291
|
-
point_cloud={k: {kk: LogPointCloud(*v()) for kk, v in v.items()} for k, v in self.point_clouds.items()},
|
526
|
+
images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
|
527
|
+
videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
|
292
528
|
)
|
293
529
|
|
294
530
|
def clear(self) -> None:
|
295
531
|
self.scalars.clear()
|
296
532
|
self.strings.clear()
|
297
533
|
self.images.clear()
|
298
|
-
self.audio.clear()
|
299
534
|
self.videos.clear()
|
300
|
-
self.histograms.clear()
|
301
|
-
self.point_clouds.clear()
|
302
535
|
|
303
536
|
def write(self, state: State) -> None:
|
304
537
|
"""Writes the current step's logging information.
|
@@ -317,11 +550,11 @@ class Logger:
|
|
317
550
|
|
318
551
|
def write_error_summary(self, error_summary: str) -> None:
|
319
552
|
for logger in self.loggers:
|
320
|
-
logger.write_error_summary(LogErrorSummary(error_summary))
|
553
|
+
logger.write_error_summary(LogErrorSummary(message=error_summary))
|
321
554
|
|
322
555
|
def write_error(self, message: str, location: str | None = None) -> None:
|
323
556
|
for logger in self.loggers:
|
324
|
-
logger.write_error(LogError(message, location))
|
557
|
+
logger.write_error(LogError(message=message, location=location))
|
325
558
|
|
326
559
|
def write_status(
|
327
560
|
self,
|
@@ -330,7 +563,12 @@ class Logger:
|
|
330
563
|
lineno: int | None = None,
|
331
564
|
created: float | None = None,
|
332
565
|
) -> None:
|
333
|
-
status = LogStatus(
|
566
|
+
status = LogStatus(
|
567
|
+
message=message,
|
568
|
+
created=time.time() if created is None else created,
|
569
|
+
filename=filename,
|
570
|
+
lineno=lineno,
|
571
|
+
)
|
334
572
|
for logger in self.loggers:
|
335
573
|
logger.write_status(status)
|
336
574
|
|
@@ -341,7 +579,12 @@ class Logger:
|
|
341
579
|
lineno: int | None = None,
|
342
580
|
created: float | None = None,
|
343
581
|
) -> None:
|
344
|
-
ping = LogPing(
|
582
|
+
ping = LogPing(
|
583
|
+
message=message,
|
584
|
+
created=time.time() if created is None else created,
|
585
|
+
filename=filename,
|
586
|
+
lineno=lineno,
|
587
|
+
)
|
345
588
|
for logger in self.loggers:
|
346
589
|
logger.write_ping(ping)
|
347
590
|
|
@@ -356,8 +599,13 @@ class Logger:
|
|
356
599
|
value: The scalar value being logged
|
357
600
|
namespace: An optional logging namespace
|
358
601
|
"""
|
602
|
+
if not self.active:
|
603
|
+
raise RuntimeError("The logger is not active")
|
359
604
|
namespace = self.resolve_namespace(namespace)
|
360
605
|
|
606
|
+
if isinstance(value, jnp.ndarray):
|
607
|
+
assert value.ndim == 0, f"Scalar must be a 0D array, got shape {value.shape}"
|
608
|
+
|
361
609
|
@functools.lru_cache(maxsize=None)
|
362
610
|
def scalar_future() -> Number:
|
363
611
|
return value() if callable(value) else value
|
@@ -372,6 +620,8 @@ class Logger:
|
|
372
620
|
value: The string value being logged
|
373
621
|
namespace: An optional logging namespace
|
374
622
|
"""
|
623
|
+
if not self.active:
|
624
|
+
raise RuntimeError("The logger is not active")
|
375
625
|
namespace = self.resolve_namespace(namespace)
|
376
626
|
|
377
627
|
@functools.lru_cache(maxsize=None)
|
@@ -383,71 +633,87 @@ class Logger:
|
|
383
633
|
def log_image(
|
384
634
|
self,
|
385
635
|
key: str,
|
386
|
-
value: Callable[[], Array] | Array,
|
636
|
+
value: Callable[[], np.ndarray | Array | PILImage] | np.ndarray | Array | PILImage,
|
387
637
|
*,
|
388
638
|
namespace: str | None = None,
|
389
|
-
|
639
|
+
target_resolution: tuple[int, int] | None = (512, 512),
|
390
640
|
) -> None:
|
391
641
|
"""Logs an image.
|
392
642
|
|
393
643
|
Args:
|
394
644
|
key: The key being logged
|
395
|
-
value: The image being logged
|
396
|
-
as an RGB (3 channel) or grayscale (1 channel) image
|
645
|
+
value: The image being logged
|
397
646
|
namespace: An optional logging namespace
|
398
|
-
|
399
|
-
|
400
|
-
resolution
|
647
|
+
target_resolution: The target resolution for each image; if None,
|
648
|
+
don't resample the images
|
401
649
|
"""
|
650
|
+
if not self.active:
|
651
|
+
raise RuntimeError("The logger is not active")
|
402
652
|
namespace = self.resolve_namespace(namespace)
|
403
653
|
|
404
654
|
@functools.lru_cache(maxsize=None)
|
405
|
-
def image_future() ->
|
406
|
-
|
655
|
+
def image_future() -> LogImage:
|
656
|
+
return get_image(value() if callable(value) else value, target_resolution)
|
407
657
|
|
408
658
|
self.images[namespace][key] = image_future
|
409
659
|
|
410
660
|
def log_labeled_image(
|
411
661
|
self,
|
412
662
|
key: str,
|
413
|
-
value: Callable[[], tuple[Array, str]] | tuple[Array, str],
|
663
|
+
value: Callable[[], tuple[np.ndarray | Array | PILImage, str]] | tuple[np.ndarray | Array | PILImage, str],
|
414
664
|
*,
|
415
665
|
namespace: str | None = None,
|
416
666
|
max_line_length: int | None = None,
|
417
|
-
|
667
|
+
max_num_lines: int | None = None,
|
668
|
+
target_resolution: tuple[int, int] | None = (512, 512),
|
669
|
+
line_spacing: int = 2,
|
418
670
|
centered: bool = True,
|
419
671
|
) -> None:
|
420
672
|
"""Logs an image with a label.
|
421
673
|
|
422
674
|
Args:
|
423
675
|
key: The key being logged
|
424
|
-
value: The image and label being logged
|
425
|
-
(H, W, C) or (H, W) as an RGB (3 channel) or grayscale
|
426
|
-
(1 channel) image
|
676
|
+
value: The image and label being logged
|
427
677
|
namespace: An optional logging namespace
|
428
|
-
max_line_length:
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
678
|
+
max_line_length: The maximum line length for the label
|
679
|
+
max_num_lines: The number of lines of spacing to add to the bottom
|
680
|
+
of the image
|
681
|
+
target_resolution: The target resolution for each image; if None,
|
682
|
+
don't resample the images
|
683
|
+
line_spacing: The spacing between adjacent lines
|
684
|
+
centered: If set, center the text labels, otherwise align to the left
|
434
685
|
"""
|
686
|
+
if not self.active:
|
687
|
+
raise RuntimeError("The logger is not active")
|
435
688
|
namespace = self.resolve_namespace(namespace)
|
436
689
|
|
437
690
|
@functools.lru_cache(maxsize=None)
|
438
|
-
def
|
439
|
-
|
691
|
+
def image_future() -> LogImage:
|
692
|
+
image, label = value() if callable(value) else value
|
693
|
+
image = get_image(image, target_resolution)
|
694
|
+
return image_with_text(
|
695
|
+
image.image,
|
696
|
+
standardize_text(label, max_line_length),
|
697
|
+
max_num_lines=max_num_lines,
|
698
|
+
line_spacing=line_spacing,
|
699
|
+
centered=centered,
|
700
|
+
)
|
440
701
|
|
441
|
-
self.images[namespace][key] =
|
702
|
+
self.images[namespace][key] = image_future
|
442
703
|
|
443
704
|
def log_images(
|
444
705
|
self,
|
445
706
|
key: str,
|
446
|
-
value:
|
707
|
+
value: (
|
708
|
+
Callable[[], Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array]
|
709
|
+
| Sequence[np.ndarray | Array | PILImage]
|
710
|
+
| np.ndarray
|
711
|
+
| Array
|
712
|
+
),
|
447
713
|
*,
|
448
714
|
namespace: str | None = None,
|
449
|
-
keep_resolution: bool = False,
|
450
715
|
max_images: int | None = None,
|
716
|
+
target_resolution: tuple[int, int] | None = (256, 256),
|
451
717
|
sep: int = 0,
|
452
718
|
) -> None:
|
453
719
|
"""Logs a set of images.
|
@@ -456,35 +722,49 @@ class Logger:
|
|
456
722
|
|
457
723
|
Args:
|
458
724
|
key: The key being logged
|
459
|
-
value: The images being logged
|
460
|
-
or (B H, W) as an RGB (3 channel) or grayscale (1 channel) image
|
725
|
+
value: The images being logged
|
461
726
|
namespace: An optional logging namespace
|
462
|
-
keep_resolution: If set, keep the image resolution the same,
|
463
|
-
otherwise upscale or downscale the image to a standard
|
464
|
-
resolution
|
465
727
|
max_images: The maximum number of images to show; extra images
|
466
728
|
are clipped
|
729
|
+
target_resolution: The target resolution for each image; if None,
|
730
|
+
don't resample the images
|
467
731
|
sep: An optional separation amount between adjacent images
|
468
732
|
"""
|
733
|
+
if not self.active:
|
734
|
+
raise RuntimeError("The logger is not active")
|
469
735
|
namespace = self.resolve_namespace(namespace)
|
470
736
|
|
471
737
|
@functools.lru_cache(maxsize=None)
|
472
|
-
def images_future() ->
|
473
|
-
|
738
|
+
def images_future() -> LogImage:
|
739
|
+
images = value() if callable(value) else value
|
740
|
+
if max_images is not None:
|
741
|
+
images = images[:max_images]
|
742
|
+
if isinstance(images, Array):
|
743
|
+
images = as_numpy(images)
|
744
|
+
if isinstance(images, Sequence):
|
745
|
+
images = list(images)
|
746
|
+
images = [get_image(image, target_resolution) for image in images]
|
747
|
+
tiled = tile_images([img.image for img in images], sep)
|
748
|
+
return LogImage(image=tiled)
|
474
749
|
|
475
750
|
self.images[namespace][key] = images_future
|
476
751
|
|
477
752
|
def log_labeled_images(
|
478
753
|
self,
|
479
754
|
key: str,
|
480
|
-
value:
|
755
|
+
value: (
|
756
|
+
Callable[[], tuple[Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array, Sequence[str]]]
|
757
|
+
| tuple[Sequence[np.ndarray | Array | PILImage] | np.ndarray | Array, Sequence[str]]
|
758
|
+
),
|
481
759
|
*,
|
482
760
|
namespace: str | None = None,
|
483
|
-
max_line_length: int | None = None,
|
484
|
-
keep_resolution: bool = False,
|
485
761
|
max_images: int | None = None,
|
486
|
-
|
762
|
+
max_line_length: int | None = None,
|
763
|
+
max_num_lines: int | None = None,
|
764
|
+
target_resolution: tuple[int, int] | None = (256, 256),
|
765
|
+
line_spacing: int = 2,
|
487
766
|
centered: bool = True,
|
767
|
+
sep: int = 0,
|
488
768
|
) -> None:
|
489
769
|
"""Logs a set of images with labels.
|
490
770
|
|
@@ -492,353 +772,79 @@ class Logger:
|
|
492
772
|
|
493
773
|
Args:
|
494
774
|
key: The key being logged
|
495
|
-
value: The images and labels being logged
|
496
|
-
(B, C, H, W), (B, H, W, C) or (B, H, W) as an RGB (3 channel)
|
497
|
-
or grayscale (1 channel) image, with exactly B labels
|
775
|
+
value: The images and labels being logged
|
498
776
|
namespace: An optional logging namespace
|
499
|
-
max_line_length: Labels longer than this length are wrapped around
|
500
|
-
keep_resolution: If set, keep the image resolution the same,
|
501
|
-
otherwise upscale or downscale the image to a standard
|
502
|
-
resolution
|
503
777
|
max_images: The maximum number of images to show; extra images
|
504
778
|
are clipped
|
779
|
+
max_line_length: The maximum line length for the label
|
780
|
+
max_num_lines: The number of lines of spacing to add to the bottom
|
781
|
+
of the image
|
782
|
+
target_resolution: The target resolution for each image; if None,
|
783
|
+
don't resample the images
|
784
|
+
line_spacing: The spacing between adjacent lines
|
785
|
+
centered: If set, center the text labels, otherwise align to the left
|
505
786
|
sep: An optional separation amount between adjacent images
|
506
|
-
centered: If set, center the text labels, otherwise align to the
|
507
|
-
left
|
508
|
-
"""
|
509
|
-
namespace = self.resolve_namespace(namespace)
|
510
|
-
|
511
|
-
@functools.lru_cache(maxsize=None)
|
512
|
-
def labeled_images_future() -> Array:
|
513
|
-
raise NotImplementedError
|
514
|
-
|
515
|
-
self.images[namespace][key] = labeled_images_future
|
516
|
-
|
517
|
-
def log_audio(
|
518
|
-
self,
|
519
|
-
key: str,
|
520
|
-
value: Callable[[], Array] | Array,
|
521
|
-
*,
|
522
|
-
namespace: str | None = None,
|
523
|
-
sample_rate: int = 44100,
|
524
|
-
log_spec: bool = False,
|
525
|
-
n_fft_ms: float = 32.0,
|
526
|
-
hop_length_ms: float | None = None,
|
527
|
-
channel_select_mode: ChannelSelectMode = "first",
|
528
|
-
keep_resolution: bool = False,
|
529
|
-
) -> None:
|
530
|
-
"""Logs an audio clip.
|
531
|
-
|
532
|
-
Args:
|
533
|
-
key: The key being logged
|
534
|
-
value: The audio clip being logged; can be (C, T) or (T) as
|
535
|
-
a mono (1 channel) or stereo (2 channel) audio clip
|
536
|
-
namespace: An optional logging namespace
|
537
|
-
sample_rate: The sample rate of the audio clip
|
538
|
-
log_spec: If set, also log the spectrogram
|
539
|
-
n_fft_ms: FFT size, in milliseconds
|
540
|
-
hop_length_ms: The FFT hop length, in milliseconds
|
541
|
-
channel_select_mode: How to select the channel if the audio is
|
542
|
-
stereo; can be "first", "last", or "mean"; this is only used
|
543
|
-
for the spectrogram
|
544
|
-
keep_resolution: If set, keep the resolution of the
|
545
|
-
spectrogram; otherwise, make human-viewable
|
546
787
|
"""
|
788
|
+
if not self.active:
|
789
|
+
raise RuntimeError("The logger is not active")
|
547
790
|
namespace = self.resolve_namespace(namespace)
|
548
791
|
|
549
792
|
@functools.lru_cache(maxsize=None)
|
550
|
-
def
|
551
|
-
|
793
|
+
def images_future() -> LogImage:
|
794
|
+
images, labels = value() if callable(value) else value
|
795
|
+
if max_images is not None:
|
796
|
+
images = images[:max_images]
|
797
|
+
labels = labels[:max_images]
|
798
|
+
images = [get_image(image, target_resolution) for image in images]
|
799
|
+
labeled = [
|
800
|
+
image_with_text(
|
801
|
+
img.image,
|
802
|
+
standardize_text(label, max_line_length),
|
803
|
+
max_num_lines=max_num_lines,
|
804
|
+
line_spacing=line_spacing,
|
805
|
+
centered=centered,
|
806
|
+
)
|
807
|
+
for img, label in zip(images, labels)
|
808
|
+
]
|
809
|
+
tiled = tile_images([img.image for img in labeled], sep)
|
810
|
+
return LogImage(image=tiled)
|
552
811
|
|
553
|
-
|
554
|
-
def audio_future() -> tuple[Array, int]:
|
555
|
-
raise NotImplementedError
|
556
|
-
|
557
|
-
self.audio[namespace][key] = audio_future
|
558
|
-
|
559
|
-
if log_spec:
|
560
|
-
# Using a unique key for the spectrogram is very important because
|
561
|
-
# otherwise Tensorboard will have some issues.
|
562
|
-
self.log_spectrogram(
|
563
|
-
key=f"{key}_spec",
|
564
|
-
value=raw_audio_future,
|
565
|
-
namespace=namespace,
|
566
|
-
sample_rate=sample_rate,
|
567
|
-
n_fft_ms=n_fft_ms,
|
568
|
-
hop_length_ms=hop_length_ms,
|
569
|
-
channel_select_mode=channel_select_mode,
|
570
|
-
keep_resolution=keep_resolution,
|
571
|
-
)
|
572
|
-
|
573
|
-
def log_audios(
|
574
|
-
self,
|
575
|
-
key: str,
|
576
|
-
value: Callable[[], Array] | Array,
|
577
|
-
*,
|
578
|
-
namespace: str | None = None,
|
579
|
-
sep_ms: float = 0.0,
|
580
|
-
max_audios: int | None = None,
|
581
|
-
sample_rate: int = 44100,
|
582
|
-
log_spec: bool = False,
|
583
|
-
n_fft_ms: float = 32.0,
|
584
|
-
hop_length_ms: float | None = None,
|
585
|
-
channel_select_mode: ChannelSelectMode = "first",
|
586
|
-
spec_sep: int = 0,
|
587
|
-
keep_resolution: bool = False,
|
588
|
-
) -> None:
|
589
|
-
"""Logs multiple audio clips.
|
590
|
-
|
591
|
-
Args:
|
592
|
-
key: The key being logged
|
593
|
-
value: The audio clip being logged; can be (B, C, T) or (B, T) as
|
594
|
-
a mono (1 channel) or stereo (2 channel) audio clip, with
|
595
|
-
exactly B clips
|
596
|
-
namespace: An optional logging namespace
|
597
|
-
sep_ms: An optional separation amount between adjacent audio clips
|
598
|
-
max_audios: An optional maximum number of audio clips to log
|
599
|
-
sample_rate: The sample rate of the audio clip
|
600
|
-
log_spec: If set, also log the spectrogram
|
601
|
-
n_fft_ms: FFT size, in milliseconds
|
602
|
-
hop_length_ms: The FFT hop length, in milliseconds
|
603
|
-
channel_select_mode: How to select the channel if the audio is
|
604
|
-
stereo; can be "first", "last", or "mean"; this is only used
|
605
|
-
for the spectrogram
|
606
|
-
spec_sep: An optional separation amount between adjacent
|
607
|
-
spectrograms
|
608
|
-
keep_resolution: If set, keep the resolution of the
|
609
|
-
spectrogram; otherwise, make human-viewable
|
610
|
-
"""
|
611
|
-
namespace = self.resolve_namespace(namespace)
|
612
|
-
|
613
|
-
@functools.lru_cache(maxsize=None)
|
614
|
-
def raw_audio_future() -> Array:
|
615
|
-
raise NotImplementedError
|
616
|
-
|
617
|
-
@functools.lru_cache(maxsize=None)
|
618
|
-
def audio_future() -> tuple[Array, int]:
|
619
|
-
raise NotImplementedError
|
620
|
-
|
621
|
-
self.audio[namespace][key] = audio_future
|
622
|
-
|
623
|
-
if log_spec:
|
624
|
-
# Using a unique key for the spectrogram is very important because
|
625
|
-
# otherwise Tensorboard will have some issues.
|
626
|
-
self.log_spectrograms(
|
627
|
-
key=f"{key}_spec",
|
628
|
-
value=raw_audio_future,
|
629
|
-
namespace=namespace,
|
630
|
-
max_audios=max_audios,
|
631
|
-
sample_rate=sample_rate,
|
632
|
-
n_fft_ms=n_fft_ms,
|
633
|
-
hop_length_ms=hop_length_ms,
|
634
|
-
channel_select_mode=channel_select_mode,
|
635
|
-
spec_sep=spec_sep,
|
636
|
-
keep_resolution=keep_resolution,
|
637
|
-
)
|
638
|
-
|
639
|
-
def log_spectrogram(
|
640
|
-
self,
|
641
|
-
key: str,
|
642
|
-
value: Callable[[], Array] | Array,
|
643
|
-
*,
|
644
|
-
namespace: str | None = None,
|
645
|
-
sample_rate: int = 44100,
|
646
|
-
n_fft_ms: float = 32.0,
|
647
|
-
hop_length_ms: float | None = None,
|
648
|
-
channel_select_mode: ChannelSelectMode = "first",
|
649
|
-
keep_resolution: bool = False,
|
650
|
-
) -> None:
|
651
|
-
"""Logs spectrograms of an audio clip.
|
652
|
-
|
653
|
-
Args:
|
654
|
-
key: The key being logged
|
655
|
-
value: The audio clip being logged; can be (C, T) or (T) as
|
656
|
-
a mono (1 channel) or stereo (2 channel) audio clip
|
657
|
-
namespace: An optional logging namespace
|
658
|
-
sample_rate: The sample rate of the audio clip
|
659
|
-
n_fft_ms: FFT size, in milliseconds
|
660
|
-
hop_length_ms: The FFT hop length, in milliseconds
|
661
|
-
channel_select_mode: How to select the channel if the audio is
|
662
|
-
stereo; can be "first", "last", or "mean"; this is only used
|
663
|
-
for the spectrogram
|
664
|
-
keep_resolution: If set, keep the resolution of the
|
665
|
-
spectrogram; otherwise, make human-viewable
|
666
|
-
"""
|
667
|
-
namespace = self.resolve_namespace(namespace)
|
668
|
-
|
669
|
-
@functools.lru_cache(maxsize=None)
|
670
|
-
def spec_future() -> Array:
|
671
|
-
raise NotImplementedError
|
672
|
-
|
673
|
-
self.images[namespace][key] = spec_future
|
674
|
-
|
675
|
-
def log_spectrograms(
|
676
|
-
self,
|
677
|
-
key: str,
|
678
|
-
value: Callable[[], Array] | Array,
|
679
|
-
*,
|
680
|
-
namespace: str | None = None,
|
681
|
-
max_audios: int | None = None,
|
682
|
-
sample_rate: int = 44100,
|
683
|
-
n_fft_ms: float = 32.0,
|
684
|
-
hop_length_ms: float | None = None,
|
685
|
-
channel_select_mode: ChannelSelectMode = "first",
|
686
|
-
spec_sep: int = 0,
|
687
|
-
keep_resolution: bool = False,
|
688
|
-
) -> None:
|
689
|
-
"""Logs spectrograms of audio clips.
|
690
|
-
|
691
|
-
Args:
|
692
|
-
key: The key being logged
|
693
|
-
value: The audio clip being logged; can be (B, C, T) or (B, T) as
|
694
|
-
a mono (1 channel) or stereo (2 channel) audio clip, with
|
695
|
-
exactly B clips
|
696
|
-
namespace: An optional logging namespace
|
697
|
-
max_audios: An optional maximum number of audio clips to log
|
698
|
-
sample_rate: The sample rate of the audio clip
|
699
|
-
n_fft_ms: FFT size, in milliseconds
|
700
|
-
hop_length_ms: The FFT hop length, in milliseconds
|
701
|
-
channel_select_mode: How to select the channel if the audio is
|
702
|
-
stereo; can be "first", "last", or "mean"; this is only used
|
703
|
-
for the spectrogram
|
704
|
-
spec_sep: An optional separation amount between adjacent
|
705
|
-
spectrograms
|
706
|
-
keep_resolution: If set, keep the resolution of the
|
707
|
-
spectrogram; otherwise, make human-viewable
|
708
|
-
"""
|
709
|
-
namespace = self.resolve_namespace(namespace)
|
710
|
-
|
711
|
-
@functools.lru_cache(maxsize=None)
|
712
|
-
def spec_future() -> Array:
|
713
|
-
raise NotImplementedError
|
812
|
+
self.images[namespace][key] = images_future
|
714
813
|
|
715
|
-
|
814
|
+
def log_file(self, name: str, contents: str) -> None:
|
815
|
+
for logger in self.loggers:
|
816
|
+
logger.log_file(name, contents)
|
716
817
|
|
717
818
|
def log_video(
|
718
819
|
self,
|
719
820
|
key: str,
|
720
|
-
value: Callable[[], Array] | Array,
|
821
|
+
value: Callable[[], np.ndarray | Array] | np.ndarray | Array,
|
721
822
|
*,
|
823
|
+
fps: int = 30,
|
722
824
|
namespace: str | None = None,
|
723
|
-
fps: int | None = None,
|
724
|
-
length: float | None = None,
|
725
825
|
) -> None:
|
726
826
|
"""Logs a video.
|
727
827
|
|
728
828
|
Args:
|
729
829
|
key: The key being logged
|
730
|
-
value: The video
|
731
|
-
(T,
|
732
|
-
(
|
830
|
+
value: The video frames. Can be:
|
831
|
+
- A numpy array of shape (T,H,W,C) or (T,C,H,W)
|
832
|
+
- A JAX array of shape (T,H,W,C) or (T,C,H,W)
|
833
|
+
fps: Frames per second
|
733
834
|
namespace: An optional logging namespace
|
734
|
-
fps: The video frames per second
|
735
|
-
length: The desired video length, in seconds, at the target FPS
|
736
835
|
"""
|
836
|
+
if not self.active:
|
837
|
+
raise RuntimeError("The logger is not active")
|
737
838
|
namespace = self.resolve_namespace(namespace)
|
738
839
|
|
739
840
|
@functools.lru_cache(maxsize=None)
|
740
|
-
def video_future() ->
|
741
|
-
|
841
|
+
def video_future() -> LogVideo:
|
842
|
+
return get_video(value() if callable(value) else value, fps=fps)
|
742
843
|
|
743
844
|
self.videos[namespace][key] = video_future
|
744
845
|
|
745
|
-
def log_videos(
|
746
|
-
self,
|
747
|
-
key: str,
|
748
|
-
value: Callable[[], Array | list[Array]] | Array | list[Array],
|
749
|
-
*,
|
750
|
-
namespace: str | None = None,
|
751
|
-
max_videos: int | None = None,
|
752
|
-
sep: int = 0,
|
753
|
-
fps: int | None = None,
|
754
|
-
length: int | None = None,
|
755
|
-
) -> None:
|
756
|
-
"""Logs a set of video.
|
757
|
-
|
758
|
-
Args:
|
759
|
-
key: The key being logged
|
760
|
-
value: The videos being logged; the video can be (B, T, C, H, W),
|
761
|
-
(B, T, H, W, C) or (B T, H, W) as an RGB (3 channel) or
|
762
|
-
grayscale (1 channel) video
|
763
|
-
namespace: An optional logging namespace
|
764
|
-
max_videos: The maximum number of videos to show; extra images
|
765
|
-
are clipped
|
766
|
-
sep: An optional separation amount between adjacent videos
|
767
|
-
fps: The video frames per second
|
768
|
-
length: The desired video length, in seconds, at the target FPS
|
769
|
-
"""
|
770
|
-
namespace = self.resolve_namespace(namespace)
|
771
|
-
|
772
|
-
@functools.lru_cache(maxsize=None)
|
773
|
-
def videos_future() -> Array:
|
774
|
-
raise NotImplementedError
|
775
|
-
|
776
|
-
self.videos[namespace][key] = videos_future
|
777
|
-
|
778
|
-
def log_histogram(
|
779
|
-
self,
|
780
|
-
key: str,
|
781
|
-
value: Callable[[], Array] | Array,
|
782
|
-
*,
|
783
|
-
namespace: str | None = None,
|
784
|
-
) -> None:
|
785
|
-
"""Logs a histogram.
|
786
|
-
|
787
|
-
Args:
|
788
|
-
key: The key being logged
|
789
|
-
value: The values to create a histogram from, with arbitrary shape
|
790
|
-
namespace: An optional logging namespace
|
791
|
-
"""
|
792
|
-
namespace = self.resolve_namespace(namespace)
|
793
|
-
|
794
|
-
@functools.lru_cache(maxsize=None)
|
795
|
-
def histogram_future() -> Array:
|
796
|
-
raise NotImplementedError
|
797
|
-
|
798
|
-
self.histograms[namespace][key] = histogram_future
|
799
|
-
|
800
|
-
def log_point_cloud(
|
801
|
-
self,
|
802
|
-
key: str,
|
803
|
-
value: Callable[[], Array] | Array,
|
804
|
-
*,
|
805
|
-
namespace: str | None = None,
|
806
|
-
max_points: int = 1000,
|
807
|
-
colors: Callable[[], Array] | Array | None = None,
|
808
|
-
) -> None:
|
809
|
-
"""Logs a point cloud.
|
810
|
-
|
811
|
-
Args:
|
812
|
-
key: The key being logged
|
813
|
-
value: The point cloud values, with shape (N, 3) or (B, ..., 3);
|
814
|
-
can pass multiple batches in order to show multiple point
|
815
|
-
clouds
|
816
|
-
namespace: An optional logging namespace
|
817
|
-
max_points: An optional maximum number of points in the point cloud
|
818
|
-
colors: An optional color for each point, with the same shape as
|
819
|
-
the point cloud
|
820
|
-
"""
|
821
|
-
namespace = self.resolve_namespace(namespace)
|
822
|
-
|
823
|
-
@functools.lru_cache(maxsize=None)
|
824
|
-
def point_cloud_future() -> tuple[Array, Array | None]:
|
825
|
-
raise NotImplementedError
|
826
|
-
|
827
|
-
self.point_clouds[namespace][key] = point_cloud_future
|
828
|
-
|
829
|
-
def log_git_state(self, git_state: str) -> None:
|
830
|
-
for logger in self.loggers:
|
831
|
-
logger.log_git_state(git_state)
|
832
|
-
|
833
|
-
def log_training_code(self, training_code: str) -> None:
|
834
|
-
for logger in self.loggers:
|
835
|
-
logger.log_training_code(training_code)
|
836
|
-
|
837
|
-
def log_config(self, config: DictConfig) -> None:
|
838
|
-
for logger in self.loggers:
|
839
|
-
logger.log_config(config)
|
840
|
-
|
841
846
|
def __enter__(self) -> Self:
|
847
|
+
self.active = True
|
842
848
|
for logger in self.loggers:
|
843
849
|
logger.start()
|
844
850
|
return self
|
@@ -846,3 +852,4 @@ class Logger:
|
|
846
852
|
def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
|
847
853
|
for logger in self.loggers:
|
848
854
|
logger.stop()
|
855
|
+
self.active = False
|