xax 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/task/logger.py CHANGED
@@ -24,7 +24,6 @@ import jax
24
24
  import jax.numpy as jnp
25
25
  import numpy as np
26
26
  from jaxtyping import Array
27
- from omegaconf import DictConfig
28
27
  from PIL import Image, ImageDraw, ImageFont
29
28
  from PIL.Image import Image as PILImage
30
29
 
@@ -92,50 +91,6 @@ def make_human_viewable_resolution(
92
91
  return image.resize((new_width, new_height), interpolation)
93
92
 
94
93
 
95
- def image_with_text(
96
- image: PILImage,
97
- text: list[str],
98
- max_num_lines: int | None,
99
- line_spacing: int,
100
- centered: bool,
101
- ) -> PILImage:
102
- """Adds a text label to an image.
103
-
104
- Args:
105
- image: The image to label, with shape (C, H, W)
106
- text: The text label for the image
107
- max_num_lines: The number of lines of spacing to add to the bottom
108
- of the image
109
- line_spacing: The spacing between adjacent lines
110
- centered: If set, center the text labels, otherwise align to the left
111
-
112
- Returns:
113
- The image with a text label
114
- """
115
- if not text:
116
- return image
117
- if max_num_lines is None:
118
- max_num_lines = len(text)
119
- else:
120
- text = text[:max_num_lines]
121
- width, height = image.size
122
- font: ImageFont.ImageFont = ImageFont.load_default()
123
- _, _, _, line_height = font.getbbox(text[0])
124
- new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)
125
- padded_image = Image.new(image.mode, (new_width, new_height), 255)
126
- padded_image.paste(image, (0, 0))
127
- drawer = ImageDraw.Draw(padded_image)
128
- for i, text_line in enumerate(text):
129
- text_line_top = height + line_spacing + i * (line_height + line_spacing)
130
- if centered:
131
- _, _, line_width, _ = font.getbbox(text_line)
132
- text_line_left = (width - line_width) / 2
133
- drawer.text((text_line_left, text_line_top), text_line, font=font, fill=0)
134
- else:
135
- drawer.text((line_spacing, text_line_top), text_line, font=font, fill=0)
136
- return padded_image
137
-
138
-
139
94
  class namespace_context: # noqa: N801
140
95
  def __init__(self, name: str | None) -> None:
141
96
  self._name = name
@@ -250,7 +205,68 @@ def as_numpy(array: Array) -> np.ndarray:
250
205
  return np.array(array)
251
206
 
252
207
 
253
- def get_image(image: np.ndarray | Array | PILImage, target_resolution: tuple[int, int] | None = None) -> PILImage:
208
@dataclass(kw_only=True)
class LogImage:
    """A single image queued for logging.

    Attributes:
        image: The PIL image payload, as built by ``get_image`` or
            ``image_with_text``.
    """

    image: PILImage
+
212
+
213
+ @dataclass(kw_only=True)
214
+ class LogVideo:
215
+ """Container for video data and metadata.
216
+
217
+ Attributes:
218
+ frames: Video frames as a numpy array of shape (T,H,W,C)
219
+ fps: Frames per second
220
+ """
221
+
222
+ frames: np.ndarray
223
+ fps: int
224
+
225
+
226
@dataclass(kw_only=True)
class LogLine:
    """Everything gathered by the logger for one write.

    Every mapping below is keyed first by namespace and then by the key
    logged within that namespace.
    """

    state: State  # the State the line was packed for
    scalars: dict[str, dict[str, Number]]
    strings: dict[str, dict[str, str]]
    images: dict[str, dict[str, LogImage]]
    videos: dict[str, dict[str, LogVideo]]
233
+
234
+
235
+ @dataclass(kw_only=True)
236
+ class LogErrorSummary:
237
+ message: str
238
+
239
+
240
+ @dataclass(kw_only=True)
241
+ class LogError:
242
+ message: str
243
+ location: str | None = None
244
+
245
+ @property
246
+ def message_with_location(self) -> str:
247
+ message = self.message
248
+ if self.location is not None:
249
+ message += f" ({self.location})"
250
+ return message
251
+
252
+
253
+ @dataclass(kw_only=True)
254
+ class LogStatus:
255
+ message: str
256
+ created: float
257
+ filename: str | None = None
258
+ lineno: int | None = None
259
+
260
+
261
+ @dataclass(kw_only=True)
262
+ class LogPing:
263
+ message: str
264
+ created: float
265
+ filename: str | None = None
266
+ lineno: int | None = None
267
+
268
+
269
+ def get_image(image: np.ndarray | Array | PILImage, target_resolution: tuple[int, int] | None = None) -> LogImage:
254
270
  if not isinstance(image, (np.ndarray, Array, PILImage)):
255
271
  raise ValueError(f"Unsupported image type: {type(image)}")
256
272
  if isinstance(image, Array):
@@ -283,54 +299,85 @@ def get_image(image: np.ndarray | Array | PILImage, target_resolution: tuple[int
283
299
 
284
300
  if target_resolution is not None:
285
301
  image = make_human_viewable_resolution(image, trg_res=target_resolution)
286
- return image
302
+ return LogImage(image=image)
287
303
 
288
304
 
289
- @dataclass
290
- class LogImage:
291
- image: PILImage
305
def image_with_text(
    image: PILImage,
    text: list[str],
    max_num_lines: int | None,
    line_spacing: int,
    centered: bool,
) -> LogImage:
    """Adds a text label below an image.

    The image is padded at the bottom with a white band tall enough for
    ``max_num_lines`` lines of text, and the label is drawn into it in black.

    Args:
        image: The PIL image to label.
        text: The label, one entry per line.
        max_num_lines: The number of lines of space to reserve below the
            image; lines beyond this are dropped. If None, one line is
            reserved per entry in ``text``.
        line_spacing: The spacing, in pixels, between adjacent lines. It is
            also used as the top margin of the band and, for left-aligned
            text, as the left margin.
        centered: If set, center each line horizontally; otherwise align
            to the left.

    Returns:
        The labeled image wrapped in a LogImage. If there is nothing to draw
        (``text`` is empty or ``max_num_lines`` is 0), the original image is
        returned unchanged.
    """
    from PIL import ImageColor  # only needed here

    if max_num_lines is None:
        max_num_lines = len(text)
    else:
        text = text[:max_num_lines]
    # Check emptiness after truncation, so that max_num_lines=0 doesn't reach
    # text[0] below.
    if not text:
        return LogImage(image=image)

    # Resolve colors for the image's own mode. A bare int like 255 means
    # "white" only for single-band modes. For RGB, Pillow reads it as a
    # packed pixel (red), and fill=0 in RGBA mode is fully transparent.
    background = ImageColor.getcolor("white", image.mode)
    foreground = ImageColor.getcolor("black", image.mode)

    width, height = image.size
    font: ImageFont.ImageFont = ImageFont.load_default()
    _, _, _, line_height = font.getbbox(text[0])
    new_height = height + line_spacing + max_num_lines * (line_height + line_spacing)
    padded_image = Image.new(image.mode, (width, new_height), background)
    padded_image.paste(image, (0, 0))
    drawer = ImageDraw.Draw(padded_image)
    for i, text_line in enumerate(text):
        text_line_top = height + line_spacing + i * (line_height + line_spacing)
        if centered:
            _, _, line_width, _ = font.getbbox(text_line)
            text_line_left = (width - line_width) / 2
        else:
            text_line_left = line_spacing
        drawer.text((text_line_left, text_line_top), text_line, font=font, fill=foreground)
    return LogImage(image=padded_image)
300
347
 
301
348
 
302
- @dataclass
303
- class LogErrorSummary:
304
- message: str
349
def get_video(video: np.ndarray | Array, fps: int = 30) -> LogVideo:
    """Converts video data to the standard (T, H, W, C) uint8 format.

    Args:
        video: The video frames. Can be:
            - A numpy array of shape (T, H, W, C) or (T, C, H, W)
            - A JAX array of shape (T, H, W, C) or (T, C, H, W)
            Input is treated as channel-first when its second axis has size
            3, or when that axis has size 1 or 4 and the last axis is not
            itself a plausible channel axis (size 1, 3 or 4).
        fps: Frames per second

    Returns:
        LogVideo containing standardized video frames

    Raises:
        ValueError: If the input is not an array, is not 4-dimensional, or
            has a non-floating dtype other than uint8.
    """
    if isinstance(video, Array):
        video = as_numpy(video)

    if not isinstance(video, np.ndarray):
        raise ValueError(f"Unsupported video type: {type(video)}")

    # Handle different dimension orderings
    if video.ndim != 4:
        raise ValueError(f"Expected video array of shape (T, H, W, C) or (T, C, H, W), got shape {video.shape}")

    # Channel-axis detection is heuristic. A size of 3 on axis 1 has always
    # meant (T, C, H, W). Grayscale/RGBA channel-first input is only
    # transposed when the trailing axis cannot be a channel axis. That avoids
    # flipping a genuine (T, H, W, C) video whose height happens to be 1 or 4.
    channel_sizes = (1, 3, 4)
    axis1, last_axis = video.shape[1], video.shape[3]
    if axis1 == 3 or (axis1 in channel_sizes and last_axis not in channel_sizes):
        video = video.transpose(0, 2, 3, 1)

    # Normalize and convert to uint8 if needed
    if np.issubdtype(video.dtype, np.floating):
        video = (normalize(video) * 255).round().astype(np.uint8)
    elif video.dtype != np.uint8:
        raise ValueError(f"Unsupported video dtype: {video.dtype}")

    return LogVideo(frames=video, fps=fps)
334
381
 
335
382
 
336
383
  class LoggerImpl(ABC):
@@ -391,25 +438,12 @@ class LoggerImpl(ABC):
391
438
  ping: The ping to write.
392
439
  """
393
440
 
394
- def log_git_state(self, git_state: str) -> None:
395
- """Logs Git state for the current run.
396
-
397
- Args:
398
- git_state: The Git state, as text blocks.
399
- """
400
-
401
- def log_training_code(self, training_code: str) -> None:
402
- """Logs the training script code.
403
-
404
- Args:
405
- training_code: The training script code.
406
- """
407
-
408
- def log_config(self, config: DictConfig) -> None:
409
- """Logs the configuration for the current run.
441
def log_file(self, name: str, contents: str) -> None:
    """Records a named, potentially large, text file for the current run.

    The base implementation does nothing; backends that can persist or
    display files override it.

    Args:
        name: The name of the file.
        contents: The full text contents of the file.
    """
    return None
414
448
 
415
449
  def should_log(self, state: State) -> bool:
@@ -464,7 +498,8 @@ class Logger:
464
498
  def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
465
499
  self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
466
500
  self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
467
- self.images: dict[str, dict[str, Callable[[], PILImage]]] = defaultdict(dict)
501
+ self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
502
+ self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
468
503
  self.default_namespace = default_namespace
469
504
  self.loggers: list[LoggerImpl] = []
470
505
 
@@ -488,13 +523,15 @@ class Logger:
488
523
  state=state,
489
524
  scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
490
525
  strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
491
- images={k: {kk: LogImage(v()) for kk, v in v.items()} for k, v in self.images.items()},
526
+ images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
527
+ videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
492
528
  )
493
529
 
494
530
def clear(self) -> None:
    """Drops every pending scalar, string, image and video future."""
    for pending in (self.scalars, self.strings, self.images, self.videos):
        pending.clear()
498
535
 
499
536
  def write(self, state: State) -> None:
500
537
  """Writes the current step's logging information.
@@ -513,11 +550,11 @@ class Logger:
513
550
 
514
551
def write_error_summary(self, error_summary: str) -> None:
    """Forwards an error summary to every attached logger backend.

    Args:
        error_summary: The summary text; each backend receives it wrapped in
            its own LogErrorSummary.
    """
    for impl in self.loggers:
        impl.write_error_summary(LogErrorSummary(message=error_summary))
517
554
 
518
555
  def write_error(self, message: str, location: str | None = None) -> None:
519
556
  for logger in self.loggers:
520
- logger.write_error(LogError(message, location))
557
+ logger.write_error(LogError(message=message, location=location))
521
558
 
522
559
  def write_status(
523
560
  self,
@@ -526,7 +563,12 @@ class Logger:
526
563
  lineno: int | None = None,
527
564
  created: float | None = None,
528
565
  ) -> None:
529
- status = LogStatus(message, time.time() if created is None else created, filename, lineno)
566
+ status = LogStatus(
567
+ message=message,
568
+ created=time.time() if created is None else created,
569
+ filename=filename,
570
+ lineno=lineno,
571
+ )
530
572
  for logger in self.loggers:
531
573
  logger.write_status(status)
532
574
 
@@ -537,7 +579,12 @@ class Logger:
537
579
  lineno: int | None = None,
538
580
  created: float | None = None,
539
581
  ) -> None:
540
- ping = LogPing(message, time.time() if created is None else created, filename, lineno)
582
+ ping = LogPing(
583
+ message=message,
584
+ created=time.time() if created is None else created,
585
+ filename=filename,
586
+ lineno=lineno,
587
+ )
541
588
  for logger in self.loggers:
542
589
  logger.write_ping(ping)
543
590
 
@@ -556,6 +603,9 @@ class Logger:
556
603
  raise RuntimeError("The logger is not active")
557
604
  namespace = self.resolve_namespace(namespace)
558
605
 
606
+ if isinstance(value, jnp.ndarray):
607
+ assert value.ndim == 0, f"Scalar must be a 0D array, got shape {value.shape}"
608
+
559
609
  @functools.lru_cache(maxsize=None)
560
610
  def scalar_future() -> Number:
561
611
  return value() if callable(value) else value
@@ -602,7 +652,7 @@ class Logger:
602
652
  namespace = self.resolve_namespace(namespace)
603
653
 
604
654
  @functools.lru_cache(maxsize=None)
605
- def image_future() -> PILImage:
655
+ def image_future() -> LogImage:
606
656
  return get_image(value() if callable(value) else value, target_resolution)
607
657
 
608
658
  self.images[namespace][key] = image_future
@@ -638,11 +688,11 @@ class Logger:
638
688
  namespace = self.resolve_namespace(namespace)
639
689
 
640
690
  @functools.lru_cache(maxsize=None)
641
- def image_future() -> PILImage:
691
+ def image_future() -> LogImage:
642
692
  image, label = value() if callable(value) else value
643
693
  image = get_image(image, target_resolution)
644
694
  return image_with_text(
645
- image,
695
+ image.image,
646
696
  standardize_text(label, max_line_length),
647
697
  max_num_lines=max_num_lines,
648
698
  line_spacing=line_spacing,
@@ -685,7 +735,7 @@ class Logger:
685
735
  namespace = self.resolve_namespace(namespace)
686
736
 
687
737
  @functools.lru_cache(maxsize=None)
688
- def images_future() -> PILImage:
738
+ def images_future() -> LogImage:
689
739
  images = value() if callable(value) else value
690
740
  if max_images is not None:
691
741
  images = images[:max_images]
@@ -694,7 +744,8 @@ class Logger:
694
744
  if isinstance(images, Sequence):
695
745
  images = list(images)
696
746
  images = [get_image(image, target_resolution) for image in images]
697
- return tile_images(images, sep)
747
+ tiled = tile_images([img.image for img in images], sep)
748
+ return LogImage(image=tiled)
698
749
 
699
750
  self.images[namespace][key] = images_future
700
751
 
@@ -739,37 +790,58 @@ class Logger:
739
790
  namespace = self.resolve_namespace(namespace)
740
791
 
741
792
  @functools.lru_cache(maxsize=None)
742
- def images_future() -> PILImage:
793
+ def images_future() -> LogImage:
743
794
  images, labels = value() if callable(value) else value
744
795
  if max_images is not None:
745
796
  images = images[:max_images]
746
797
  labels = labels[:max_images]
747
798
  images = [get_image(image, target_resolution) for image in images]
748
- images = [
799
+ labeled = [
749
800
  image_with_text(
750
- image,
801
+ img.image,
751
802
  standardize_text(label, max_line_length),
752
803
  max_num_lines=max_num_lines,
753
804
  line_spacing=line_spacing,
754
805
  centered=centered,
755
806
  )
756
- for image, label in zip(images, labels)
807
+ for img, label in zip(images, labels)
757
808
  ]
758
- return tile_images(images, sep)
809
+ tiled = tile_images([img.image for img in labeled], sep)
810
+ return LogImage(image=tiled)
759
811
 
760
812
  self.images[namespace][key] = images_future
761
813
 
762
- def log_git_state(self, git_state: str) -> None:
814
def log_file(self, name: str, contents: str) -> None:
    """Sends a named text file to every attached logger backend.

    Args:
        name: The name of the file.
        contents: The contents of the file.
    """
    for impl in self.loggers:
        impl.log_file(name, contents)
765
817
 
766
- def log_training_code(self, training_code: str) -> None:
767
- for logger in self.loggers:
768
- logger.log_training_code(training_code)
818
+ def log_video(
819
+ self,
820
+ key: str,
821
+ value: Callable[[], np.ndarray | Array] | np.ndarray | Array,
822
+ *,
823
+ fps: int = 30,
824
+ namespace: str | None = None,
825
+ ) -> None:
826
+ """Logs a video.
769
827
 
770
- def log_config(self, config: DictConfig) -> None:
771
- for logger in self.loggers:
772
- logger.log_config(config)
828
+ Args:
829
+ key: The key being logged
830
+ value: The video frames. Can be:
831
+ - A numpy array of shape (T,H,W,C) or (T,C,H,W)
832
+ - A JAX array of shape (T,H,W,C) or (T,C,H,W)
833
+ fps: Frames per second
834
+ namespace: An optional logging namespace
835
+ """
836
+ if not self.active:
837
+ raise RuntimeError("The logger is not active")
838
+ namespace = self.resolve_namespace(namespace)
839
+
840
+ @functools.lru_cache(maxsize=None)
841
+ def video_future() -> LogVideo:
842
+ return get_video(value() if callable(value) else value, fps=fps)
843
+
844
+ self.videos[namespace][key] = video_future
773
845
 
774
846
  def __enter__(self) -> Self:
775
847
  self.active = True
@@ -2,8 +2,6 @@
2
2
 
3
3
  from typing import Callable
4
4
 
5
- from omegaconf import DictConfig
6
-
7
5
  from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
8
6
 
9
7
 
@@ -16,9 +14,7 @@ class CallbackLogger(LoggerImpl):
16
14
  error_callback: Callable[[LogError], None] = lambda x: None,
17
15
  status_callback: Callable[[LogStatus], None] = lambda x: None,
18
16
  ping_callback: Callable[[LogPing], None] = lambda x: None,
19
- git_state_callback: Callable[[str], None] = lambda x: None,
20
- training_code_callback: Callable[[str], None] = lambda x: None,
21
- config_callback: Callable[[DictConfig], None] = lambda x: None,
17
+ file_callback: Callable[[str, str], None] = lambda x, y: None,
22
18
  ) -> None:
23
19
  super().__init__()
24
20
 
@@ -27,9 +23,7 @@ class CallbackLogger(LoggerImpl):
27
23
  self.error_callback = error_callback
28
24
  self.status_callback = status_callback
29
25
  self.ping_callback = ping_callback
30
- self.git_state_callback = git_state_callback
31
- self.training_code_callback = training_code_callback
32
- self.config_callback = config_callback
26
+ self.file_callback = file_callback
33
27
 
34
28
  def write(self, line: LogLine) -> None:
35
29
  self.callback(line)
@@ -46,11 +40,5 @@ class CallbackLogger(LoggerImpl):
46
40
  def write_ping(self, ping: LogPing) -> None:
47
41
  self.ping_callback(ping)
48
42
 
49
- def log_git_state(self, git_state: str) -> None:
50
- self.git_state_callback(git_state)
51
-
52
- def log_training_code(self, training_code: str) -> None:
53
- self.training_code_callback(training_code)
54
-
55
- def log_config(self, config: DictConfig) -> None:
56
- self.config_callback(config)
43
def log_file(self, name: str, contents: str) -> None:
    """Hands the file name and contents to the user-supplied ``file_callback``."""
    callback = self.file_callback
    callback(name, contents)
xax/task/loggers/state.py CHANGED
@@ -3,8 +3,6 @@
3
3
  from pathlib import Path
4
4
  from typing import Literal
5
5
 
6
- from omegaconf import DictConfig, OmegaConf
7
-
8
6
  from xax.task.logger import LoggerImpl, LogLine
9
7
 
10
8
 
@@ -12,9 +10,6 @@ class StateLogger(LoggerImpl):
12
10
  def __init__(
13
11
  self,
14
12
  run_directory: str | Path,
15
- git_state_name: str = "git_state.txt",
16
- train_code_name: str = "train_code.py",
17
- config_name: str = "config.yaml",
18
13
  flush_immediately: bool = False,
19
14
  open_mode: Literal["w", "a"] = "w",
20
15
  line_sep: str = "\n",
@@ -22,24 +17,16 @@ class StateLogger(LoggerImpl):
22
17
  ) -> None:
23
18
  super().__init__(float("inf"))
24
19
 
25
- self.git_state_file = Path(run_directory).expanduser().resolve() / git_state_name
26
- self.train_code_file = Path(run_directory).expanduser().resolve() / train_code_name
27
- self.config_file = Path(run_directory).expanduser().resolve() / config_name
20
+ self.run_directory = Path(run_directory).expanduser().resolve()
21
+
28
22
  self.flush_immediately = flush_immediately
29
23
  self.open_mode = open_mode
30
24
  self.line_sep = line_sep
31
25
  self.remove_unicode_from_namespaces = remove_unicode_from_namespaces
32
26
 
33
- def log_git_state(self, git_state: str) -> None:
34
- with open(self.git_state_file, "w") as f:
35
- f.write(git_state)
36
-
37
- def log_training_code(self, training_code: str) -> None:
38
- with open(self.train_code_file, "w") as f:
39
- f.write(training_code)
40
-
41
- def log_config(self, config: DictConfig) -> None:
42
- OmegaConf.save(config, self.config_file)
27
def log_file(self, name: str, contents: str) -> None:
    """Writes a text file into the run directory.

    Args:
        name: The file name, relative to ``self.run_directory``. It may
            include subdirectories, which are created if missing.
        contents: The text to write. An existing file is overwritten.
    """
    path = self.run_directory / name
    # Nested names such as "configs/train.yaml" should not fail just because
    # the subdirectory has not been created yet.
    path.parent.mkdir(parents=True, exist_ok=True)
    # Pin the encoding so the output does not depend on the platform locale.
    path.write_text(contents, encoding="utf-8")
43
30
 
44
31
  def write(self, line: LogLine) -> None:
45
32
  pass
@@ -12,8 +12,6 @@ import time
12
12
  from pathlib import Path
13
13
  from typing import TypeVar
14
14
 
15
- from omegaconf import DictConfig, OmegaConf
16
-
17
15
  from xax.core.state import Phase
18
16
  from xax.nn.parallel import is_master
19
17
  from xax.task.logger import LoggerImpl, LogLine
@@ -60,10 +58,7 @@ class TensorboardLogger(LoggerImpl):
60
58
 
61
59
  self.proc: subprocess.Popen | None = None
62
60
 
63
- self.git_state: str | None = None
64
- self.training_code: str | None = None
65
- self.config: DictConfig | None = None
66
-
61
+ self.files: dict[str, str] = {}
67
62
  self.writers = TensorboardWriters(log_directory=self.log_directory, flush_seconds=flush_seconds)
68
63
  self._started = False
69
64
 
@@ -158,20 +153,10 @@ class TensorboardLogger(LoggerImpl):
158
153
  self._start()
159
154
  return self.writers.writer(phase)
160
155
 
161
- def log_git_state(self, git_state: str) -> None:
162
- if not is_master():
163
- return
164
- self.git_state = f"```\n{git_state}\n```"
165
-
166
- def log_training_code(self, training_code: str) -> None:
156
def log_file(self, name: str, contents: str) -> None:
    """Queues a text file to be shown in TensorBoard as a Markdown code block.

    Only the master process records files. Queued files are emitted with
    ``add_text`` and then cleared on the next call to ``write``.

    Args:
        name: The tag under which the file is shown.
        contents: The text contents of the file.
    """
    import re  # only needed here

    if not is_master():
        return
    # Use a fence longer than any backtick run inside the contents, so that
    # files which themselves contain ``` (e.g. Markdown) don't close the code
    # block early. Contents without backticks keep the plain ``` fence.
    longest_run = max((len(run) for run in re.findall(r"`+", contents)), default=0)
    fence = "`" * max(3, longest_run + 1)
    self.files[name] = f"{fence}\n{contents}\n{fence}"
175
160
 
176
161
  def write(self, line: LogLine) -> None:
177
162
  if not is_master():
@@ -210,14 +195,15 @@ class TensorboardLogger(LoggerImpl):
210
195
  walltime=walltime,
211
196
  )
212
197
 
213
- if self.config is not None:
214
- writer.add_text("config", f"```\n{OmegaConf.to_yaml(self.config)}\n```")
215
- self.config = None
216
-
217
- if self.git_state is not None:
218
- writer.add_text("git", self.git_state)
219
- self.git_state = None
198
+ for namespace, videos in line.videos.items():
199
+ for video_key, video_value in videos.items():
200
+ writer.add_video(
201
+ f"{namespace}/{video_key}",
202
+ video_value.frames,
203
+ fps=video_value.fps,
204
+ global_step=line.state.num_steps,
205
+ )
220
206
 
221
- if self.training_code is not None:
222
- writer.add_text("code", self.training_code)
223
- self.training_code = None
207
+ for name, contents in self.files.items():
208
+ writer.add_text(name, contents)
209
+ self.files.clear()
@@ -2,6 +2,7 @@
2
2
 
3
3
  from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
4
4
  from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
5
+ from xax.task.mixins.compile import CompileConfig, CompileMixin
5
6
  from xax.task.mixins.cpu_stats import CPUStatsConfig, CPUStatsMixin
6
7
  from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
7
8
  from xax.task.mixins.gpu_stats import GPUStatsConfig, GPUStatsMixin
@@ -8,6 +8,8 @@ from dataclasses import dataclass
8
8
  from pathlib import Path
9
9
  from typing import Self, TypeVar
10
10
 
11
+ import jax
12
+
11
13
  from xax.core.conf import field, get_run_dir
12
14
  from xax.core.state import State
13
15
  from xax.nn.parallel import is_master
@@ -19,6 +21,7 @@ from xax.utils.text import show_info
19
21
  logger = logging.getLogger(__name__)
20
22
 
21
23
 
24
+ @jax.tree_util.register_dataclass
22
25
  @dataclass
23
26
  class ArtifactsConfig(BaseConfig):
24
27
  exp_dir: str | None = field(None, help="The fixed experiment directory")
@@ -43,14 +46,14 @@ class ArtifactsMixin(BaseTask[Config]):
43
46
  run_dir = Path(task_file).resolve().parent
44
47
  return run_dir / self.task_name
45
48
 
46
- def set_exp_dir(self, exp_dir: Path) -> Self:
47
- self._exp_dir = exp_dir
48
- return self
49
-
50
49
  @property
51
50
  def exp_dir(self) -> Path:
52
51
  return self.get_exp_dir()
53
52
 
53
+ def set_exp_dir(self, exp_dir: str | Path) -> Self:
54
+ self._exp_dir = Path(exp_dir).expanduser().resolve()
55
+ return self
56
+
54
57
  def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
55
58
  if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
56
59
  if not exists_ok: