xax 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +102 -2
- xax/core/conf.py +8 -33
- xax/core/state.py +13 -23
- xax/nn/geom.py +75 -0
- xax/requirements.txt +2 -0
- xax/task/base.py +2 -0
- xax/task/logger.py +194 -122
- xax/task/loggers/callback.py +4 -16
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +14 -28
- xax/task/mixins/__init__.py +1 -0
- xax/task/mixins/artifacts.py +7 -4
- xax/task/mixins/checkpointing.py +12 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +16 -5
- xax/task/mixins/data_loader.py +23 -12
- xax/task/mixins/gpu_stats.py +19 -5
- xax/task/mixins/logger.py +4 -2
- xax/task/mixins/process.py +4 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +189 -129
- xax/task/script.py +1 -1
- xax/task/task.py +7 -0
- xax/utils/jax.py +126 -0
- xax/utils/profile.py +61 -0
- xax/utils/pytree.py +50 -0
- xax/utils/tensorboard.py +48 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/METADATA +12 -2
- xax-0.0.7.dist-info/RECORD +55 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.5.dist-info/RECORD +0 -52
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/LICENSE +0 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/top_level.txt +0 -0
xax/task/logger.py
CHANGED
@@ -24,7 +24,6 @@ import jax
|
|
24
24
|
import jax.numpy as jnp
|
25
25
|
import numpy as np
|
26
26
|
from jaxtyping import Array
|
27
|
-
from omegaconf import DictConfig
|
28
27
|
from PIL import Image, ImageDraw, ImageFont
|
29
28
|
from PIL.Image import Image as PILImage
|
30
29
|
|
@@ -92,50 +91,6 @@ def make_human_viewable_resolution(
|
|
92
91
|
return image.resize((new_width, new_height), interpolation)
|
93
92
|
|
94
93
|
|
95
|
-
def image_with_text(
|
96
|
-
image: PILImage,
|
97
|
-
text: list[str],
|
98
|
-
max_num_lines: int | None,
|
99
|
-
line_spacing: int,
|
100
|
-
centered: bool,
|
101
|
-
) -> PILImage:
|
102
|
-
"""Adds a text label to an image.
|
103
|
-
|
104
|
-
Args:
|
105
|
-
image: The image to label, with shape (C, H, W)
|
106
|
-
text: The text label for the image
|
107
|
-
max_num_lines: The number of lines of spacing to add to the bottom
|
108
|
-
of the image
|
109
|
-
line_spacing: The spacing between adjacent lines
|
110
|
-
centered: If set, center the text labels, otherwise align to the left
|
111
|
-
|
112
|
-
Returns:
|
113
|
-
The image with a text label
|
114
|
-
"""
|
115
|
-
if not text:
|
116
|
-
return image
|
117
|
-
if max_num_lines is None:
|
118
|
-
max_num_lines = len(text)
|
119
|
-
else:
|
120
|
-
text = text[:max_num_lines]
|
121
|
-
width, height = image.size
|
122
|
-
font: ImageFont.ImageFont = ImageFont.load_default()
|
123
|
-
_, _, _, line_height = font.getbbox(text[0])
|
124
|
-
new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)
|
125
|
-
padded_image = Image.new(image.mode, (new_width, new_height), 255)
|
126
|
-
padded_image.paste(image, (0, 0))
|
127
|
-
drawer = ImageDraw.Draw(padded_image)
|
128
|
-
for i, text_line in enumerate(text):
|
129
|
-
text_line_top = height + line_spacing + i * (line_height + line_spacing)
|
130
|
-
if centered:
|
131
|
-
_, _, line_width, _ = font.getbbox(text_line)
|
132
|
-
text_line_left = (width - line_width) / 2
|
133
|
-
drawer.text((text_line_left, text_line_top), text_line, font=font, fill=0)
|
134
|
-
else:
|
135
|
-
drawer.text((line_spacing, text_line_top), text_line, font=font, fill=0)
|
136
|
-
return padded_image
|
137
|
-
|
138
|
-
|
139
94
|
class namespace_context: # noqa: N801
|
140
95
|
def __init__(self, name: str | None) -> None:
|
141
96
|
self._name = name
|
@@ -250,7 +205,68 @@ def as_numpy(array: Array) -> np.ndarray:
|
|
250
205
|
return np.array(array)
|
251
206
|
|
252
207
|
|
253
|
-
|
208
|
+
@dataclass(kw_only=True)
|
209
|
+
class LogImage:
|
210
|
+
image: PILImage
|
211
|
+
|
212
|
+
|
213
|
+
@dataclass(kw_only=True)
|
214
|
+
class LogVideo:
|
215
|
+
"""Container for video data and metadata.
|
216
|
+
|
217
|
+
Attributes:
|
218
|
+
frames: Video frames as a numpy array of shape (T,H,W,C)
|
219
|
+
fps: Frames per second
|
220
|
+
"""
|
221
|
+
|
222
|
+
frames: np.ndarray
|
223
|
+
fps: int
|
224
|
+
|
225
|
+
|
226
|
+
@dataclass(kw_only=True)
|
227
|
+
class LogLine:
|
228
|
+
state: State
|
229
|
+
scalars: dict[str, dict[str, Number]]
|
230
|
+
strings: dict[str, dict[str, str]]
|
231
|
+
images: dict[str, dict[str, LogImage]]
|
232
|
+
videos: dict[str, dict[str, LogVideo]]
|
233
|
+
|
234
|
+
|
235
|
+
@dataclass(kw_only=True)
|
236
|
+
class LogErrorSummary:
|
237
|
+
message: str
|
238
|
+
|
239
|
+
|
240
|
+
@dataclass(kw_only=True)
|
241
|
+
class LogError:
|
242
|
+
message: str
|
243
|
+
location: str | None = None
|
244
|
+
|
245
|
+
@property
|
246
|
+
def message_with_location(self) -> str:
|
247
|
+
message = self.message
|
248
|
+
if self.location is not None:
|
249
|
+
message += f" ({self.location})"
|
250
|
+
return message
|
251
|
+
|
252
|
+
|
253
|
+
@dataclass(kw_only=True)
|
254
|
+
class LogStatus:
|
255
|
+
message: str
|
256
|
+
created: float
|
257
|
+
filename: str | None = None
|
258
|
+
lineno: int | None = None
|
259
|
+
|
260
|
+
|
261
|
+
@dataclass(kw_only=True)
|
262
|
+
class LogPing:
|
263
|
+
message: str
|
264
|
+
created: float
|
265
|
+
filename: str | None = None
|
266
|
+
lineno: int | None = None
|
267
|
+
|
268
|
+
|
269
|
+
def get_image(image: np.ndarray | Array | PILImage, target_resolution: tuple[int, int] | None = None) -> LogImage:
|
254
270
|
if not isinstance(image, (np.ndarray, Array, PILImage)):
|
255
271
|
raise ValueError(f"Unsupported image type: {type(image)}")
|
256
272
|
if isinstance(image, Array):
|
@@ -283,54 +299,85 @@ def get_image(image: np.ndarray | Array | PILImage, target_resolution: tuple[int
|
|
283
299
|
|
284
300
|
if target_resolution is not None:
|
285
301
|
image = make_human_viewable_resolution(image, trg_res=target_resolution)
|
286
|
-
return image
|
302
|
+
return LogImage(image=image)
|
287
303
|
|
288
304
|
|
289
|
-
|
290
|
-
|
291
|
-
|
305
|
+
def image_with_text(
|
306
|
+
image: PILImage,
|
307
|
+
text: list[str],
|
308
|
+
max_num_lines: int | None,
|
309
|
+
line_spacing: int,
|
310
|
+
centered: bool,
|
311
|
+
) -> LogImage:
|
312
|
+
"""Adds a text label to an image.
|
292
313
|
|
314
|
+
Args:
|
315
|
+
image: The image to label, with shape (C, H, W)
|
316
|
+
text: The text label for the image
|
317
|
+
max_num_lines: The number of lines of spacing to add to the bottom
|
318
|
+
of the image
|
319
|
+
line_spacing: The spacing between adjacent lines
|
320
|
+
centered: If set, center the text labels, otherwise align to the left
|
293
321
|
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
322
|
+
Returns:
|
323
|
+
The image with a text label
|
324
|
+
"""
|
325
|
+
if not text:
|
326
|
+
return LogImage(image=image)
|
327
|
+
if max_num_lines is None:
|
328
|
+
max_num_lines = len(text)
|
329
|
+
else:
|
330
|
+
text = text[:max_num_lines]
|
331
|
+
width, height = image.size
|
332
|
+
font: ImageFont.ImageFont = ImageFont.load_default()
|
333
|
+
_, _, _, line_height = font.getbbox(text[0])
|
334
|
+
new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)
|
335
|
+
padded_image = Image.new(image.mode, (new_width, new_height), 255)
|
336
|
+
padded_image.paste(image, (0, 0))
|
337
|
+
drawer = ImageDraw.Draw(padded_image)
|
338
|
+
for i, text_line in enumerate(text):
|
339
|
+
text_line_top = height + line_spacing + i * (line_height + line_spacing)
|
340
|
+
if centered:
|
341
|
+
_, _, line_width, _ = font.getbbox(text_line)
|
342
|
+
text_line_left = (width - line_width) / 2
|
343
|
+
drawer.text((text_line_left, text_line_top), text_line, font=font, fill=0)
|
344
|
+
else:
|
345
|
+
drawer.text((line_spacing, text_line_top), text_line, font=font, fill=0)
|
346
|
+
return LogImage(image=padded_image)
|
300
347
|
|
301
348
|
|
302
|
-
|
303
|
-
|
304
|
-
message: str
|
349
|
+
def get_video(video: np.ndarray | Array, fps: int = 30) -> LogVideo:
|
350
|
+
"""Converts video data to standard format.
|
305
351
|
|
352
|
+
Args:
|
353
|
+
video: The video frames. Can be:
|
354
|
+
- A numpy array of shape (T, H, W, C) or (T, C, H, W)
|
355
|
+
- A JAX array of shape (T, H, W, C) or (T, C, H, W)
|
356
|
+
fps: Frames per second
|
306
357
|
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
358
|
+
Returns:
|
359
|
+
LogVideo containing standardized video frames
|
360
|
+
"""
|
361
|
+
if isinstance(video, Array):
|
362
|
+
video = as_numpy(video)
|
311
363
|
|
312
|
-
|
313
|
-
|
314
|
-
message = self.message
|
315
|
-
if self.location is not None:
|
316
|
-
message += f" ({self.location})"
|
317
|
-
return message
|
364
|
+
if not isinstance(video, np.ndarray):
|
365
|
+
raise ValueError(f"Unsupported video type: {type(video)}")
|
318
366
|
|
367
|
+
# Handle different dimension orderings
|
368
|
+
if video.ndim != 4:
|
369
|
+
raise ValueError(f"Expected video array of shape (T, H, W, C) or (T, C, H, W), got shape {video.shape}")
|
319
370
|
|
320
|
-
|
321
|
-
|
322
|
-
message: str
|
323
|
-
created: float
|
324
|
-
filename: str | None = None
|
325
|
-
lineno: int | None = None
|
371
|
+
if video.shape[1] == 3: # (T,C,H,W) format
|
372
|
+
video = video.transpose(0, 2, 3, 1)
|
326
373
|
|
374
|
+
# Normalize and convert to uint8 if needed
|
375
|
+
if np.issubdtype(video.dtype, np.floating):
|
376
|
+
video = (normalize(video) * 255).round().astype(np.uint8)
|
377
|
+
elif video.dtype != np.uint8:
|
378
|
+
raise ValueError(f"Unsupported video dtype: {video.dtype}")
|
327
379
|
|
328
|
-
|
329
|
-
class LogPing:
|
330
|
-
message: str
|
331
|
-
created: float
|
332
|
-
filename: str | None = None
|
333
|
-
lineno: int | None = None
|
380
|
+
return LogVideo(frames=video, fps=fps)
|
334
381
|
|
335
382
|
|
336
383
|
class LoggerImpl(ABC):
|
@@ -391,25 +438,12 @@ class LoggerImpl(ABC):
|
|
391
438
|
ping: The ping to write.
|
392
439
|
"""
|
393
440
|
|
394
|
-
def
|
395
|
-
"""Logs
|
396
|
-
|
397
|
-
Args:
|
398
|
-
git_state: The Git state, as text blocks.
|
399
|
-
"""
|
400
|
-
|
401
|
-
def log_training_code(self, training_code: str) -> None:
|
402
|
-
"""Logs the training script code.
|
403
|
-
|
404
|
-
Args:
|
405
|
-
training_code: The training script code.
|
406
|
-
"""
|
407
|
-
|
408
|
-
def log_config(self, config: DictConfig) -> None:
|
409
|
-
"""Logs the configuration for the current run.
|
441
|
+
def log_file(self, name: str, contents: str) -> None:
|
442
|
+
"""Logs a large text file.
|
410
443
|
|
411
444
|
Args:
|
412
|
-
|
445
|
+
name: The name of the file.
|
446
|
+
contents: The contents of the file.
|
413
447
|
"""
|
414
448
|
|
415
449
|
def should_log(self, state: State) -> bool:
|
@@ -464,7 +498,8 @@ class Logger:
|
|
464
498
|
def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
|
465
499
|
self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
|
466
500
|
self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
|
467
|
-
self.images: dict[str, dict[str, Callable[[],
|
501
|
+
self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
|
502
|
+
self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
|
468
503
|
self.default_namespace = default_namespace
|
469
504
|
self.loggers: list[LoggerImpl] = []
|
470
505
|
|
@@ -488,13 +523,15 @@ class Logger:
|
|
488
523
|
state=state,
|
489
524
|
scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
|
490
525
|
strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
|
491
|
-
images={k: {kk:
|
526
|
+
images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
|
527
|
+
videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
|
492
528
|
)
|
493
529
|
|
494
530
|
def clear(self) -> None:
|
495
531
|
self.scalars.clear()
|
496
532
|
self.strings.clear()
|
497
533
|
self.images.clear()
|
534
|
+
self.videos.clear()
|
498
535
|
|
499
536
|
def write(self, state: State) -> None:
|
500
537
|
"""Writes the current step's logging information.
|
@@ -513,11 +550,11 @@ class Logger:
|
|
513
550
|
|
514
551
|
def write_error_summary(self, error_summary: str) -> None:
|
515
552
|
for logger in self.loggers:
|
516
|
-
logger.write_error_summary(LogErrorSummary(error_summary))
|
553
|
+
logger.write_error_summary(LogErrorSummary(message=error_summary))
|
517
554
|
|
518
555
|
def write_error(self, message: str, location: str | None = None) -> None:
|
519
556
|
for logger in self.loggers:
|
520
|
-
logger.write_error(LogError(message, location))
|
557
|
+
logger.write_error(LogError(message=message, location=location))
|
521
558
|
|
522
559
|
def write_status(
|
523
560
|
self,
|
@@ -526,7 +563,12 @@ class Logger:
|
|
526
563
|
lineno: int | None = None,
|
527
564
|
created: float | None = None,
|
528
565
|
) -> None:
|
529
|
-
status = LogStatus(
|
566
|
+
status = LogStatus(
|
567
|
+
message=message,
|
568
|
+
created=time.time() if created is None else created,
|
569
|
+
filename=filename,
|
570
|
+
lineno=lineno,
|
571
|
+
)
|
530
572
|
for logger in self.loggers:
|
531
573
|
logger.write_status(status)
|
532
574
|
|
@@ -537,7 +579,12 @@ class Logger:
|
|
537
579
|
lineno: int | None = None,
|
538
580
|
created: float | None = None,
|
539
581
|
) -> None:
|
540
|
-
ping = LogPing(
|
582
|
+
ping = LogPing(
|
583
|
+
message=message,
|
584
|
+
created=time.time() if created is None else created,
|
585
|
+
filename=filename,
|
586
|
+
lineno=lineno,
|
587
|
+
)
|
541
588
|
for logger in self.loggers:
|
542
589
|
logger.write_ping(ping)
|
543
590
|
|
@@ -556,6 +603,9 @@ class Logger:
|
|
556
603
|
raise RuntimeError("The logger is not active")
|
557
604
|
namespace = self.resolve_namespace(namespace)
|
558
605
|
|
606
|
+
if isinstance(value, jnp.ndarray):
|
607
|
+
assert value.ndim == 0, f"Scalar must be a 0D array, got shape {value.shape}"
|
608
|
+
|
559
609
|
@functools.lru_cache(maxsize=None)
|
560
610
|
def scalar_future() -> Number:
|
561
611
|
return value() if callable(value) else value
|
@@ -602,7 +652,7 @@ class Logger:
|
|
602
652
|
namespace = self.resolve_namespace(namespace)
|
603
653
|
|
604
654
|
@functools.lru_cache(maxsize=None)
|
605
|
-
def image_future() ->
|
655
|
+
def image_future() -> LogImage:
|
606
656
|
return get_image(value() if callable(value) else value, target_resolution)
|
607
657
|
|
608
658
|
self.images[namespace][key] = image_future
|
@@ -638,11 +688,11 @@ class Logger:
|
|
638
688
|
namespace = self.resolve_namespace(namespace)
|
639
689
|
|
640
690
|
@functools.lru_cache(maxsize=None)
|
641
|
-
def image_future() ->
|
691
|
+
def image_future() -> LogImage:
|
642
692
|
image, label = value() if callable(value) else value
|
643
693
|
image = get_image(image, target_resolution)
|
644
694
|
return image_with_text(
|
645
|
-
image,
|
695
|
+
image.image,
|
646
696
|
standardize_text(label, max_line_length),
|
647
697
|
max_num_lines=max_num_lines,
|
648
698
|
line_spacing=line_spacing,
|
@@ -685,7 +735,7 @@ class Logger:
|
|
685
735
|
namespace = self.resolve_namespace(namespace)
|
686
736
|
|
687
737
|
@functools.lru_cache(maxsize=None)
|
688
|
-
def images_future() ->
|
738
|
+
def images_future() -> LogImage:
|
689
739
|
images = value() if callable(value) else value
|
690
740
|
if max_images is not None:
|
691
741
|
images = images[:max_images]
|
@@ -694,7 +744,8 @@ class Logger:
|
|
694
744
|
if isinstance(images, Sequence):
|
695
745
|
images = list(images)
|
696
746
|
images = [get_image(image, target_resolution) for image in images]
|
697
|
-
|
747
|
+
tiled = tile_images([img.image for img in images], sep)
|
748
|
+
return LogImage(image=tiled)
|
698
749
|
|
699
750
|
self.images[namespace][key] = images_future
|
700
751
|
|
@@ -739,37 +790,58 @@ class Logger:
|
|
739
790
|
namespace = self.resolve_namespace(namespace)
|
740
791
|
|
741
792
|
@functools.lru_cache(maxsize=None)
|
742
|
-
def images_future() ->
|
793
|
+
def images_future() -> LogImage:
|
743
794
|
images, labels = value() if callable(value) else value
|
744
795
|
if max_images is not None:
|
745
796
|
images = images[:max_images]
|
746
797
|
labels = labels[:max_images]
|
747
798
|
images = [get_image(image, target_resolution) for image in images]
|
748
|
-
|
799
|
+
labeled = [
|
749
800
|
image_with_text(
|
750
|
-
image,
|
801
|
+
img.image,
|
751
802
|
standardize_text(label, max_line_length),
|
752
803
|
max_num_lines=max_num_lines,
|
753
804
|
line_spacing=line_spacing,
|
754
805
|
centered=centered,
|
755
806
|
)
|
756
|
-
for
|
807
|
+
for img, label in zip(images, labels)
|
757
808
|
]
|
758
|
-
|
809
|
+
tiled = tile_images([img.image for img in labeled], sep)
|
810
|
+
return LogImage(image=tiled)
|
759
811
|
|
760
812
|
self.images[namespace][key] = images_future
|
761
813
|
|
762
|
-
def
|
814
|
+
def log_file(self, name: str, contents: str) -> None:
|
763
815
|
for logger in self.loggers:
|
764
|
-
logger.
|
816
|
+
logger.log_file(name, contents)
|
765
817
|
|
766
|
-
def
|
767
|
-
|
768
|
-
|
818
|
+
def log_video(
|
819
|
+
self,
|
820
|
+
key: str,
|
821
|
+
value: Callable[[], np.ndarray | Array] | np.ndarray | Array,
|
822
|
+
*,
|
823
|
+
fps: int = 30,
|
824
|
+
namespace: str | None = None,
|
825
|
+
) -> None:
|
826
|
+
"""Logs a video.
|
769
827
|
|
770
|
-
|
771
|
-
|
772
|
-
|
828
|
+
Args:
|
829
|
+
key: The key being logged
|
830
|
+
value: The video frames. Can be:
|
831
|
+
- A numpy array of shape (T,H,W,C) or (T,C,H,W)
|
832
|
+
- A JAX array of shape (T,H,W,C) or (T,C,H,W)
|
833
|
+
fps: Frames per second
|
834
|
+
namespace: An optional logging namespace
|
835
|
+
"""
|
836
|
+
if not self.active:
|
837
|
+
raise RuntimeError("The logger is not active")
|
838
|
+
namespace = self.resolve_namespace(namespace)
|
839
|
+
|
840
|
+
@functools.lru_cache(maxsize=None)
|
841
|
+
def video_future() -> LogVideo:
|
842
|
+
return get_video(value() if callable(value) else value, fps=fps)
|
843
|
+
|
844
|
+
self.videos[namespace][key] = video_future
|
773
845
|
|
774
846
|
def __enter__(self) -> Self:
|
775
847
|
self.active = True
|
xax/task/loggers/callback.py
CHANGED
@@ -2,8 +2,6 @@
|
|
2
2
|
|
3
3
|
from typing import Callable
|
4
4
|
|
5
|
-
from omegaconf import DictConfig
|
6
|
-
|
7
5
|
from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
|
8
6
|
|
9
7
|
|
@@ -16,9 +14,7 @@ class CallbackLogger(LoggerImpl):
|
|
16
14
|
error_callback: Callable[[LogError], None] = lambda x: None,
|
17
15
|
status_callback: Callable[[LogStatus], None] = lambda x: None,
|
18
16
|
ping_callback: Callable[[LogPing], None] = lambda x: None,
|
19
|
-
|
20
|
-
training_code_callback: Callable[[str], None] = lambda x: None,
|
21
|
-
config_callback: Callable[[DictConfig], None] = lambda x: None,
|
17
|
+
file_callback: Callable[[str, str], None] = lambda x, y: None,
|
22
18
|
) -> None:
|
23
19
|
super().__init__()
|
24
20
|
|
@@ -27,9 +23,7 @@ class CallbackLogger(LoggerImpl):
|
|
27
23
|
self.error_callback = error_callback
|
28
24
|
self.status_callback = status_callback
|
29
25
|
self.ping_callback = ping_callback
|
30
|
-
self.
|
31
|
-
self.training_code_callback = training_code_callback
|
32
|
-
self.config_callback = config_callback
|
26
|
+
self.file_callback = file_callback
|
33
27
|
|
34
28
|
def write(self, line: LogLine) -> None:
|
35
29
|
self.callback(line)
|
@@ -46,11 +40,5 @@ class CallbackLogger(LoggerImpl):
|
|
46
40
|
def write_ping(self, ping: LogPing) -> None:
|
47
41
|
self.ping_callback(ping)
|
48
42
|
|
49
|
-
def
|
50
|
-
self.
|
51
|
-
|
52
|
-
def log_training_code(self, training_code: str) -> None:
|
53
|
-
self.training_code_callback(training_code)
|
54
|
-
|
55
|
-
def log_config(self, config: DictConfig) -> None:
|
56
|
-
self.config_callback(config)
|
43
|
+
def log_file(self, name: str, contents: str) -> None:
|
44
|
+
self.file_callback(name, contents)
|
xax/task/loggers/state.py
CHANGED
@@ -3,8 +3,6 @@
|
|
3
3
|
from pathlib import Path
|
4
4
|
from typing import Literal
|
5
5
|
|
6
|
-
from omegaconf import DictConfig, OmegaConf
|
7
|
-
|
8
6
|
from xax.task.logger import LoggerImpl, LogLine
|
9
7
|
|
10
8
|
|
@@ -12,9 +10,6 @@ class StateLogger(LoggerImpl):
|
|
12
10
|
def __init__(
|
13
11
|
self,
|
14
12
|
run_directory: str | Path,
|
15
|
-
git_state_name: str = "git_state.txt",
|
16
|
-
train_code_name: str = "train_code.py",
|
17
|
-
config_name: str = "config.yaml",
|
18
13
|
flush_immediately: bool = False,
|
19
14
|
open_mode: Literal["w", "a"] = "w",
|
20
15
|
line_sep: str = "\n",
|
@@ -22,24 +17,16 @@ class StateLogger(LoggerImpl):
|
|
22
17
|
) -> None:
|
23
18
|
super().__init__(float("inf"))
|
24
19
|
|
25
|
-
self.
|
26
|
-
|
27
|
-
self.config_file = Path(run_directory).expanduser().resolve() / config_name
|
20
|
+
self.run_directory = Path(run_directory).expanduser().resolve()
|
21
|
+
|
28
22
|
self.flush_immediately = flush_immediately
|
29
23
|
self.open_mode = open_mode
|
30
24
|
self.line_sep = line_sep
|
31
25
|
self.remove_unicode_from_namespaces = remove_unicode_from_namespaces
|
32
26
|
|
33
|
-
def
|
34
|
-
with open(self.
|
35
|
-
f.write(
|
36
|
-
|
37
|
-
def log_training_code(self, training_code: str) -> None:
|
38
|
-
with open(self.train_code_file, "w") as f:
|
39
|
-
f.write(training_code)
|
40
|
-
|
41
|
-
def log_config(self, config: DictConfig) -> None:
|
42
|
-
OmegaConf.save(config, self.config_file)
|
27
|
+
def log_file(self, name: str, contents: str) -> None:
|
28
|
+
with open(self.run_directory / name, "w") as f:
|
29
|
+
f.write(contents)
|
43
30
|
|
44
31
|
def write(self, line: LogLine) -> None:
|
45
32
|
pass
|
xax/task/loggers/tensorboard.py
CHANGED
@@ -12,8 +12,6 @@ import time
|
|
12
12
|
from pathlib import Path
|
13
13
|
from typing import TypeVar
|
14
14
|
|
15
|
-
from omegaconf import DictConfig, OmegaConf
|
16
|
-
|
17
15
|
from xax.core.state import Phase
|
18
16
|
from xax.nn.parallel import is_master
|
19
17
|
from xax.task.logger import LoggerImpl, LogLine
|
@@ -60,10 +58,7 @@ class TensorboardLogger(LoggerImpl):
|
|
60
58
|
|
61
59
|
self.proc: subprocess.Popen | None = None
|
62
60
|
|
63
|
-
self.
|
64
|
-
self.training_code: str | None = None
|
65
|
-
self.config: DictConfig | None = None
|
66
|
-
|
61
|
+
self.files: dict[str, str] = {}
|
67
62
|
self.writers = TensorboardWriters(log_directory=self.log_directory, flush_seconds=flush_seconds)
|
68
63
|
self._started = False
|
69
64
|
|
@@ -158,20 +153,10 @@ class TensorboardLogger(LoggerImpl):
|
|
158
153
|
self._start()
|
159
154
|
return self.writers.writer(phase)
|
160
155
|
|
161
|
-
def
|
162
|
-
if not is_master():
|
163
|
-
return
|
164
|
-
self.git_state = f"```\n{git_state}\n```"
|
165
|
-
|
166
|
-
def log_training_code(self, training_code: str) -> None:
|
156
|
+
def log_file(self, name: str, contents: str) -> None:
|
167
157
|
if not is_master():
|
168
158
|
return
|
169
|
-
self.
|
170
|
-
|
171
|
-
def log_config(self, config: DictConfig) -> None:
|
172
|
-
if not is_master():
|
173
|
-
return
|
174
|
-
self.config = config
|
159
|
+
self.files[name] = f"```\n{contents}\n```"
|
175
160
|
|
176
161
|
def write(self, line: LogLine) -> None:
|
177
162
|
if not is_master():
|
@@ -210,14 +195,15 @@ class TensorboardLogger(LoggerImpl):
|
|
210
195
|
walltime=walltime,
|
211
196
|
)
|
212
197
|
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
198
|
+
for namespace, videos in line.videos.items():
|
199
|
+
for video_key, video_value in videos.items():
|
200
|
+
writer.add_video(
|
201
|
+
f"{namespace}/{video_key}",
|
202
|
+
video_value.frames,
|
203
|
+
fps=video_value.fps,
|
204
|
+
global_step=line.state.num_steps,
|
205
|
+
)
|
220
206
|
|
221
|
-
|
222
|
-
writer.add_text(
|
223
|
-
|
207
|
+
for name, contents in self.files.items():
|
208
|
+
writer.add_text(name, contents)
|
209
|
+
self.files.clear()
|
xax/task/mixins/__init__.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
|
4
4
|
from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
|
5
|
+
from xax.task.mixins.compile import CompileConfig, CompileMixin
|
5
6
|
from xax.task.mixins.cpu_stats import CPUStatsConfig, CPUStatsMixin
|
6
7
|
from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
|
7
8
|
from xax.task.mixins.gpu_stats import GPUStatsConfig, GPUStatsMixin
|
xax/task/mixins/artifacts.py
CHANGED
@@ -8,6 +8,8 @@ from dataclasses import dataclass
|
|
8
8
|
from pathlib import Path
|
9
9
|
from typing import Self, TypeVar
|
10
10
|
|
11
|
+
import jax
|
12
|
+
|
11
13
|
from xax.core.conf import field, get_run_dir
|
12
14
|
from xax.core.state import State
|
13
15
|
from xax.nn.parallel import is_master
|
@@ -19,6 +21,7 @@ from xax.utils.text import show_info
|
|
19
21
|
logger = logging.getLogger(__name__)
|
20
22
|
|
21
23
|
|
24
|
+
@jax.tree_util.register_dataclass
|
22
25
|
@dataclass
|
23
26
|
class ArtifactsConfig(BaseConfig):
|
24
27
|
exp_dir: str | None = field(None, help="The fixed experiment directory")
|
@@ -43,14 +46,14 @@ class ArtifactsMixin(BaseTask[Config]):
|
|
43
46
|
run_dir = Path(task_file).resolve().parent
|
44
47
|
return run_dir / self.task_name
|
45
48
|
|
46
|
-
def set_exp_dir(self, exp_dir: Path) -> Self:
|
47
|
-
self._exp_dir = exp_dir
|
48
|
-
return self
|
49
|
-
|
50
49
|
@property
|
51
50
|
def exp_dir(self) -> Path:
|
52
51
|
return self.get_exp_dir()
|
53
52
|
|
53
|
+
def set_exp_dir(self, exp_dir: str | Path) -> Self:
|
54
|
+
self._exp_dir = Path(exp_dir).expanduser().resolve()
|
55
|
+
return self
|
56
|
+
|
54
57
|
def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
|
55
58
|
if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
|
56
59
|
if not exists_ok:
|