xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. xax/__init__.py +256 -1
  2. xax/core/conf.py +193 -0
  3. xax/core/state.py +81 -0
  4. xax/nn/__init__.py +0 -0
  5. xax/nn/embeddings.py +355 -0
  6. xax/nn/functions.py +77 -0
  7. xax/nn/parallel.py +211 -0
  8. xax/requirements-dev.txt +15 -0
  9. xax/requirements.txt +23 -0
  10. xax/task/__init__.py +0 -0
  11. xax/task/base.py +207 -0
  12. xax/task/launchers/__init__.py +0 -0
  13. xax/task/launchers/base.py +28 -0
  14. xax/task/launchers/cli.py +42 -0
  15. xax/task/launchers/single_process.py +30 -0
  16. xax/task/launchers/staged.py +29 -0
  17. xax/task/logger.py +783 -0
  18. xax/task/loggers/__init__.py +0 -0
  19. xax/task/loggers/callback.py +56 -0
  20. xax/task/loggers/json.py +121 -0
  21. xax/task/loggers/state.py +45 -0
  22. xax/task/loggers/stdout.py +170 -0
  23. xax/task/loggers/tensorboard.py +223 -0
  24. xax/task/mixins/__init__.py +12 -0
  25. xax/task/mixins/artifacts.py +114 -0
  26. xax/task/mixins/checkpointing.py +209 -0
  27. xax/task/mixins/cpu_stats.py +251 -0
  28. xax/task/mixins/data_loader.py +149 -0
  29. xax/task/mixins/gpu_stats.py +257 -0
  30. xax/task/mixins/logger.py +66 -0
  31. xax/task/mixins/process.py +51 -0
  32. xax/task/mixins/runnable.py +63 -0
  33. xax/task/mixins/step_wrapper.py +63 -0
  34. xax/task/mixins/train.py +541 -0
  35. xax/task/script.py +53 -0
  36. xax/task/task.py +65 -0
  37. xax/utils/__init__.py +0 -0
  38. xax/utils/data/__init__.py +0 -0
  39. xax/utils/data/collate.py +206 -0
  40. xax/utils/experiments.py +802 -0
  41. xax/utils/jax.py +14 -0
  42. xax/utils/logging.py +223 -0
  43. xax/utils/numpy.py +47 -0
  44. xax/utils/tensorboard.py +258 -0
  45. xax/utils/text.py +350 -0
  46. xax-0.0.5.dist-info/METADATA +40 -0
  47. xax-0.0.5.dist-info/RECORD +52 -0
  48. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
  49. xax-0.0.5.dist-info/top_level.txt +1 -0
  50. examples/mnist.py +0 -148
  51. xax-0.0.1.dist-info/METADATA +0 -21
  52. xax-0.0.1.dist-info/RECORD +0 -9
  53. xax-0.0.1.dist-info/top_level.txt +0 -2
  54. {examples → xax/core}/__init__.py +0 -0
  55. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
File without changes
@@ -0,0 +1,56 @@
1
+ """Defines a logger that calls a callback function with the log line."""
2
+
3
+ from typing import Callable
4
+
5
+ from omegaconf import DictConfig
6
+
7
+ from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
8
+
9
+
10
class CallbackLogger(LoggerImpl):
    """Logger backend that hands each logging event to a user-supplied callable.

    Every hook defaults to a no-op lambda, so callers only pass the callbacks
    they care about. No buffering or formatting happens here: each
    ``write_*`` / ``log_*`` method forwards its argument unchanged.
    """

    def __init__(
        self,
        *,
        callback: Callable[[LogLine], None] = lambda x: None,
        error_summary_callback: Callable[[LogErrorSummary], None] = lambda x: None,
        error_callback: Callable[[LogError], None] = lambda x: None,
        status_callback: Callable[[LogStatus], None] = lambda x: None,
        ping_callback: Callable[[LogPing], None] = lambda x: None,
        git_state_callback: Callable[[str], None] = lambda x: None,
        training_code_callback: Callable[[str], None] = lambda x: None,
        config_callback: Callable[[DictConfig], None] = lambda x: None,
    ) -> None:
        super().__init__()

        # Each hook is kept as a public attribute and looked up at call time,
        # so it can be replaced after construction.
        (
            self.callback,
            self.error_summary_callback,
            self.error_callback,
            self.status_callback,
            self.ping_callback,
            self.git_state_callback,
            self.training_code_callback,
            self.config_callback,
        ) = (
            callback,
            error_summary_callback,
            error_callback,
            status_callback,
            ping_callback,
            git_state_callback,
            training_code_callback,
            config_callback,
        )

    def write(self, line: LogLine) -> None:
        """Forward a regular log line to ``callback``."""
        self.callback(line)

    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
        """Forward an error summary to ``error_summary_callback``."""
        self.error_summary_callback(error_summary)

    def write_error(self, error: LogError) -> None:
        """Forward a single error to ``error_callback``."""
        self.error_callback(error)

    def write_status(self, status: LogStatus) -> None:
        """Forward a status message to ``status_callback``."""
        self.status_callback(status)

    def write_ping(self, ping: LogPing) -> None:
        """Forward a ping message to ``ping_callback``."""
        self.ping_callback(ping)

    def log_git_state(self, git_state: str) -> None:
        """Forward the git-state text to ``git_state_callback``."""
        self.git_state_callback(git_state)

    def log_training_code(self, training_code: str) -> None:
        """Forward the training source code to ``training_code_callback``."""
        self.training_code_callback(training_code)

    def log_config(self, config: DictConfig) -> None:
        """Forward the run config to ``config_callback``."""
        self.config_callback(config)
@@ -0,0 +1,121 @@
1
+ """Defines a logger which logs JSON lines to a file."""
2
+
3
+ import json
4
+ import sys
5
+ from dataclasses import asdict
6
+ from typing import Any, Literal, TextIO
7
+
8
+ from jaxtyping import Array
9
+
10
+ from xax.task.logger import LogError, LoggerImpl, LogLine, LogPing, LogStatus
11
+
12
+
13
def get_json_value(value: Any) -> Any:  # noqa: ANN401
    """Return ``value`` in a form suitable for ``json.dumps``.

    JAX arrays are unwrapped to a Python scalar with ``.item()``; any other
    value is returned untouched.
    """
    return value.item() if isinstance(value, Array) else value
17
+
18
+
19
class JsonLogger(LoggerImpl):
    """Logger backend that emits one JSON object per line to a text stream.

    Log lines, pings and statuses go to ``log_stream``; error messages are
    written as plain text to ``err_log_stream``.
    """

    def __init__(
        self,
        log_stream: TextIO = sys.stdout,
        err_log_stream: TextIO = sys.stderr,
        flush_immediately: bool = False,
        open_mode: Literal["w", "a"] = "w",
        line_sep: str = "\n",
        remove_unicode_from_namespaces: bool = True,
        log_interval_seconds: float = 10.0,
    ) -> None:
        """Defines a logger which writes JSON lines to a stream.

        Args:
            log_stream: The stream that log lines, pings and statuses are
                written to.
            err_log_stream: The stream that error messages are written to.
            flush_immediately: Whether to flush the stream after every write.
            open_mode: The file open mode. Stored on the instance but not
                used by this class, which writes to already-open streams.
            line_sep: The line separator appended after every record.
            remove_unicode_from_namespaces: Whether to remove non-ASCII
                characters from namespaces. Namespaces often carry emoji /
                ASCII-art prefixes for visibility in other logs; in the JSON
                output those characters are dropped and the remainder is
                stripped of surrounding whitespace.
            log_interval_seconds: The interval between successive log lines.
        """
        super().__init__(log_interval_seconds)

        self.log_stream = log_stream
        self.err_log_stream = err_log_stream
        self.flush_immediately = flush_immediately
        self.open_mode = open_mode
        self.line_sep = line_sep
        self.remove_unicode_from_namespaces = remove_unicode_from_namespaces

    @property
    def fp(self) -> TextIO:
        """The stream used for log lines, pings and statuses."""
        return self.log_stream

    @property
    def err_fp(self) -> TextIO:
        """The stream used for error messages."""
        return self.err_log_stream

    def _write_line(self, fp: TextIO, text: str) -> None:
        """Write ``text`` plus the line separator to ``fp``, flushing if configured."""
        fp.write(text)
        fp.write(self.line_sep)
        if self.flush_immediately:
            fp.flush()

    @staticmethod
    def _event_json(kind: str, event: LogPing | LogStatus) -> str:
        """Serialize a ping/status event as ``{kind: {message, filename, lineno}}``."""
        return json.dumps(
            {
                kind: {
                    "message": event.message,
                    "filename": event.filename,
                    "lineno": event.lineno,
                },
            },
        )

    def get_json(self, line: LogLine) -> str:
        """Serialize a log line (state, scalars and strings) to a JSON string.

        The output has a ``"state"`` key holding ``asdict(line.state)``, plus one
        key per (optionally ASCII-stripped) namespace mapping value names to
        JSON-compatible values. Namespaces that collapse to the same name after
        stripping are merged into one object.
        """
        data: dict = {"state": asdict(line.state)}

        def add_logs(log: dict[str, dict[str, Any]], data: dict) -> None:
            for namespace, values in log.items():
                if self.remove_unicode_from_namespaces:
                    # Drop non-ASCII characters (e.g. emoji prefixes) and
                    # the whitespace that usually separated them.
                    namespace = namespace.encode("ascii", errors="ignore").decode("ascii").strip()
                if namespace not in data:
                    data[namespace] = {}
                for k, v in values.items():
                    data[namespace][k] = get_json_value(v)

        add_logs(line.scalars, data)
        add_logs(line.strings, data)
        return json.dumps(data)

    def write(self, line: LogLine) -> None:
        """Write ``line`` as a single JSON record to ``fp``."""
        self._write_line(self.fp, self.get_json(line))

    def write_error(self, error: LogError) -> None:
        """Write the error message (and its location, if known) as plain text to ``err_fp``."""
        message = error.message
        if error.location is not None:
            message = f"{message} ({error.location})"
        self._write_line(self.err_fp, message)

    def write_ping(self, ping: LogPing) -> None:
        """Write a ping as a ``{"ping": {...}}`` JSON record to ``fp``."""
        self._write_line(self.fp, self._event_json("ping", ping))

    def write_status(self, status: LogStatus) -> None:
        """Write a status as a ``{"status": {...}}`` JSON record to ``fp``."""
        self._write_line(self.fp, self._event_json("status", status))
@@ -0,0 +1,45 @@
1
+ """Defines a logger which logs the current training state."""
2
+
3
+ from pathlib import Path
4
+ from typing import Literal
5
+
6
+ from omegaconf import DictConfig, OmegaConf
7
+
8
+ from xax.task.logger import LoggerImpl, LogLine
9
+
10
+
11
class StateLogger(LoggerImpl):
    """Logger backend that saves run metadata (git state, training code, config) to files.

    Per-step log lines are ignored; only the one-off ``log_*`` hooks write
    anything. Files are created directly inside ``run_directory``.
    """

    def __init__(
        self,
        run_directory: str | Path,
        git_state_name: str = "git_state.txt",
        train_code_name: str = "train_code.py",
        config_name: str = "config.yaml",
        flush_immediately: bool = False,
        open_mode: Literal["w", "a"] = "w",
        line_sep: str = "\n",
        remove_unicode_from_namespaces: bool = True,
    ) -> None:
        """Defines a logger which writes run metadata files.

        Args:
            run_directory: Directory that the metadata files are written into.
            git_state_name: File name for the git state text.
            train_code_name: File name for the training source code.
            config_name: File name for the YAML config.
            flush_immediately: Stored on the instance; not used by this class.
            open_mode: Stored on the instance; not used by this class (files
                are always overwritten).
            line_sep: Stored on the instance; not used by this class.
            remove_unicode_from_namespaces: Stored on the instance; not used
                by this class.
        """
        # NOTE(review): an infinite interval presumably tells LoggerImpl that
        # this backend never needs periodic writes -- confirm against LoggerImpl.
        super().__init__(float("inf"))

        run_dir = Path(run_directory).expanduser().resolve()
        self.git_state_file = run_dir / git_state_name
        self.train_code_file = run_dir / train_code_name
        self.config_file = run_dir / config_name
        self.flush_immediately = flush_immediately
        self.open_mode = open_mode
        self.line_sep = line_sep
        self.remove_unicode_from_namespaces = remove_unicode_from_namespaces

    def log_git_state(self, git_state: str) -> None:
        """Overwrite the git-state file with ``git_state``."""
        # Explicit UTF-8 so non-ASCII content does not depend on the locale encoding.
        self.git_state_file.write_text(git_state, encoding="utf-8")

    def log_training_code(self, training_code: str) -> None:
        """Overwrite the training-code file with ``training_code``."""
        # Training code may contain non-ASCII text (e.g. emoji namespace
        # names), so do not rely on the platform's default encoding.
        self.train_code_file.write_text(training_code, encoding="utf-8")

    def log_config(self, config: DictConfig) -> None:
        """Save ``config`` as YAML to the config file."""
        OmegaConf.save(config, self.config_file)

    def write(self, line: LogLine) -> None:
        """Ignore per-step log lines; this backend only records metadata."""
        pass
@@ -0,0 +1,170 @@
1
+ """Defines a logger that logs to stdout."""
2
+
3
+ import datetime
4
+ import logging
5
+ import sys
6
+ from collections import deque
7
+ from typing import Any, Deque, TextIO
8
+
9
+ from jaxtyping import Array
10
+
11
+ from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
12
+ from xax.utils.text import Color, colored, format_timedelta
13
+
14
+
15
+ def format_number(value: int | float, precision: int) -> str:
16
+ if isinstance(value, int):
17
+ return str(value)
18
+ return f"{value:.{precision}g}"
19
+
20
+
21
def as_str(value: Any, precision: int) -> str:  # noqa: ANN401
    """Format a logged value as display text.

    Strings are wrapped in double quotes. JAX arrays are first unwrapped with
    ``.item()``. Ints and floats are formatted with ``format_number``. Any
    other type raises ``TypeError``.
    """
    if isinstance(value, str):
        return f'"{value}"'
    scalar = value.item() if isinstance(value, Array) else value
    if not isinstance(scalar, (int, float)):
        raise TypeError(f"Unexpected log type: {type(scalar)}")
    return format_number(scalar, precision)
29
+
30
+
31
class StdoutLogger(LoggerImpl):
    """Logger backend that redraws a plain-text dashboard on a text stream.

    Each ``write`` clears the screen with ANSI escape codes and re-renders the
    training state, the latest logged values, and queues of recent
    status / ping / error messages.
    """

    def __init__(
        self,
        write_fp: TextIO = sys.stdout,
        precision: int = 4,
        log_timers: bool = False,
        log_perf: bool = False,
        log_optim: bool = False,
        log_fp: bool = False,
        log_interval_seconds: float = 1.0,
        remove_temporary_after: datetime.timedelta = datetime.timedelta(seconds=10),
    ) -> None:
        """Defines a logger which redraws a text dashboard on ``write_fp``.

        Args:
            write_fp: The stream to write logs to.
            precision: The number of significant digits used when formatting
                float scalars.
            log_timers: Whether to show namespaces starting with "⏰".
            log_perf: Whether to show namespaces starting with "🔧".
            log_optim: Whether to show namespaces starting with "📉".
            log_fp: Whether to show namespaces starting with "⚖️".
            log_interval_seconds: The interval between successive log lines.
            remove_temporary_after: How long ping messages stay on screen
                before being dropped.
        """
        super().__init__(log_interval_seconds)

        self.write_fp = write_fp
        self.log_timers = log_timers
        self.log_perf = log_perf
        self.log_fp = log_fp
        self.log_optim = log_optim
        self.precision = precision
        self.remove_temporary_after = remove_temporary_after
        self.logger = logging.getLogger("stdout")

        # Recent messages, each paired with the time it was received.
        self.statuses: Deque[tuple[str, datetime.datetime]] = deque()
        self.pings: Deque[tuple[str, datetime.datetime]] = deque()
        self.errors: Deque[tuple[str, datetime.datetime]] = deque()
        # Only the most recent error summary is kept.
        self.error_summary: tuple[str, datetime.datetime] | None = None

    def start(self) -> None:
        return super().start()

    def stop(self) -> None:
        # Write any queued status / ping / error messages before stopping.
        self.write_queues()
        return super().stop()

    def write_separator(self) -> None:
        """Clear the terminal and move the cursor home (ANSI ``ESC[2J`` / ``ESC[H``)."""
        self.write_fp.write("\033[2J\033[H")

    def write_state_window(self, line: LogLine) -> None:
        """Write the phase header followed by step count, sample count and elapsed time."""
        elapsed_time = format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True)
        state_info = {
            "Steps": f"{line.state.num_steps}",
            "Samples": f"{line.state.num_samples}",
            "Elapsed Time": f"{elapsed_time}",
        }

        colored_prefix = colored("Phase: ", "grey", bold=True)
        colored_phase = colored(line.state.phase, "green" if line.state.phase == "train" else "yellow", bold=True)
        self.write_fp.write(f"{colored_prefix}{colored_phase}\n")
        for k, v in state_info.items():
            self.write_fp.write(f" ↪ {k}: {colored(v, 'cyan')}\n")

    def write_log_window(self, line: LogLine) -> None:
        """Write the scalars and strings of ``line``, grouped and sorted by namespace.

        Namespaces whose emoji prefix is disabled by the ``log_*`` flags are
        skipped. Nothing is written if no namespace remains.
        """
        namespace_to_lines: dict[str, dict[str, str]] = {}

        def add_logs(log: dict[str, dict[str, Any]], namespace_to_lines: dict[str, dict[str, str]]) -> None:
            for namespace, values in log.items():
                if not self.log_timers and namespace.startswith("⏰"):
                    continue
                if not self.log_perf and namespace.startswith("🔧"):
                    continue
                if not self.log_optim and namespace.startswith("📉"):
                    continue
                if not self.log_fp and namespace.startswith("⚖️"):
                    continue
                if namespace not in namespace_to_lines:
                    namespace_to_lines[namespace] = {}
                for k, v in values.items():
                    v_str = as_str(v, self.precision)
                    namespace_to_lines[namespace][k] = v_str

        add_logs(line.scalars, namespace_to_lines)
        add_logs(line.strings, namespace_to_lines)
        if not namespace_to_lines:
            return

        self.write_fp.write("\n")
        for namespace, lines in sorted(namespace_to_lines.items()):
            self.write_fp.write(f"{colored(namespace, 'cyan', bold=True)}\n")
            for k, v in lines.items():
                self.write_fp.write(f" ↪ {k}: {v}\n")

    def write_queue(self, title: str, q: Deque[tuple[str, datetime.datetime]], remove: bool, color: Color) -> None:
        """Write the messages in ``q`` (newest first) under ``title``.

        If ``remove`` is set, entries older than ``remove_temporary_after``
        are dropped from the front of the queue after being written.
        """
        if not q:
            return

        self.write_fp.write(f"\n{colored(title, 'grey', bold=True)}\n")
        self.write_fp.write("\n".join(f" ✦ {colored(msg, color)}" for msg, _ in reversed(q)))
        self.write_fp.write("\n")

        if remove:
            now = datetime.datetime.now()
            # Messages are appended in arrival order, so expired ones sit at the left end.
            while q and now - q[0][1] > self.remove_temporary_after:
                q.popleft()

    def write_queues(self) -> None:
        """Write the status, ping and error queues; only pings expire."""
        self.write_queue("Status", self.statuses, False, "green")
        self.write_queue("Pings", self.pings, True, "cyan")
        self.write_queue("Errors", self.errors, False, "red")

    def write_error_summary_to_screen(self) -> None:
        """Write the latest error summary with its timestamp, if there is one."""
        if self.error_summary is not None:
            summary, timestamp = self.error_summary
            timestamp_string = timestamp.strftime("%Y-%m-%d %H:%M:%S")
            self.write_fp.write(f"\n{colored('Exception summary', 'grey', bold=True)}")
            self.write_fp.write(f" {colored(timestamp_string, 'grey')}")
            self.write_fp.write(f"\n{summary}")

    def write(self, line: LogLine) -> None:
        """Redraw the whole dashboard for ``line``."""
        self.write_separator()
        self.write_state_window(line)
        self.write_log_window(line)
        self.write_queues()
        self.write_error_summary_to_screen()
        # Flush the stream that was actually written to. Flushing sys.stdout
        # here would leave a custom ``write_fp`` (e.g. a file) unflushed.
        self.write_fp.flush()

    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
        """Replace the stored error summary; it is shown on the next redraw."""
        self.error_summary = error_summary.message, datetime.datetime.now()

    def write_error(self, error: LogError) -> None:
        """Queue an error message (with location) for display."""
        self.errors.append((error.message_with_location, datetime.datetime.now()))

    def write_status(self, status: LogStatus) -> None:
        """Queue a status message for display."""
        self.statuses.append((status.message, datetime.datetime.now()))

    def write_ping(self, ping: LogPing) -> None:
        """Queue a ping message; it expires after ``remove_temporary_after``."""
        self.pings.append((ping.message, datetime.datetime.now()))
@@ -0,0 +1,223 @@
1
+ """Defines a Tensorboard logger backend."""
2
+
3
+ import atexit
4
+ import functools
5
+ import logging
6
+ import os
7
+ import re
8
+ import shutil
9
+ import subprocess
10
+ import threading
11
+ import time
12
+ from pathlib import Path
13
+ from typing import TypeVar
14
+
15
+ from omegaconf import DictConfig, OmegaConf
16
+
17
+ from xax.core.state import Phase
18
+ from xax.nn.parallel import is_master
19
+ from xax.task.logger import LoggerImpl, LogLine
20
+ from xax.utils.jax import as_float
21
+ from xax.utils.logging import LOG_STATUS, port_is_busy
22
+ from xax.utils.tensorboard import TensorboardWriter, TensorboardWriters
23
+
24
# Port for the TensorBoard server when the TENSORBOARD_PORT env var is unset.
DEFAULT_TENSORBOARD_PORT = 9249

# Module logger used for this backend's warnings.
logger: logging.Logger = logging.getLogger(__name__)

# Generic type variable (not referenced by the code visible in this module).
T = TypeVar("T")
29
+
30
+
31
class TensorboardLogger(LoggerImpl):
    """Logger backend that writes scalars, text and images to TensorBoard event files.

    On the master process it can also launch a ``tensorboard`` server (in a
    background daemon thread) that serves ``<run_directory>/<subdirectory>``.
    """

    def __init__(
        self,
        run_directory: str | Path,
        subdirectory: str = "tensorboard",
        flush_seconds: float = 10.0,
        wait_seconds: float = 0.0,
        start_in_subprocess: bool = True,
        use_localhost: bool = False,
        log_interval_seconds: float = 10.0,
    ) -> None:
        """Defines a logger which writes to Tensorboard.

        Args:
            run_directory: The root run directory.
            subdirectory: The subdirectory of the run directory to write
                Tensorboard logs to.
            flush_seconds: How often to flush logs.
            wait_seconds: Time to wait before starting Tensorboard process.
            start_in_subprocess: Start TensorBoard subprocess.
            use_localhost: Use localhost for TensorBoard address.
            log_interval_seconds: The interval between successive log lines.
        """
        super().__init__(log_interval_seconds)

        self.log_directory = Path(run_directory).expanduser().resolve() / subdirectory
        self.wait_seconds = wait_seconds
        self.start_in_subprocess = start_in_subprocess
        self.use_localhost = use_localhost

        self.proc: subprocess.Popen | None = None

        # Pending metadata, written as TensorBoard text on the next ``write``.
        self.git_state: str | None = None
        self.training_code: str | None = None
        self.config: DictConfig | None = None

        self.writers = TensorboardWriters(log_directory=self.log_directory, flush_seconds=flush_seconds)
        self._started = False
        # Per-instance "already cleared" guard. A functools.lru_cache on the
        # method would hold a strong reference to ``self`` for the life of the
        # process, so the logger (and its writers) could never be collected.
        self._logs_cleared = False

    def _start(self) -> None:
        """Start the TensorBoard server thread once, on the master process only."""
        if self._started:
            return

        if is_master():
            threading.Thread(target=self.worker_thread, daemon=True).start()

        self._started = True

    def worker_thread(self) -> None:
        """Launch the TensorBoard server and report its URL.

        Intended to run on the daemon thread started by ``_start``. It blocks
        while the chosen port is busy, then reads the server's output until
        the "TensorBoard ..." line appears and logs the URL at ``LOG_STATUS``.
        It keeps draining the server's output afterwards.

        Raises:
            RuntimeError: If the server's output ends before the URL line.
        """
        import sys  # Only needed to locate the running interpreter.

        time.sleep(self.wait_seconds)

        port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))

        while port_is_busy(port):
            logger.warning("Port %s is busy, waiting...", port)
            time.sleep(10)

        def make_localhost(s: str) -> str:
            # Optionally rewrite the advertised host name to "localhost".
            if self.use_localhost:
                s = re.sub(rf"://(.+?):{port}", f"://localhost:{port}", s)
            return s

        def parse_url(s: str) -> str:
            # Extract the first space-delimited http(s) URL, if present.
            m = re.search(r" (http\S+?) ", s)
            if m is None:
                return s
            return f"Tensorboard: {m.group(1)}"

        command: list[str] = [
            # Use the current interpreter so TensorBoard runs in the same
            # environment; a bare "python" may be missing or a different install.
            sys.executable,
            "-m",
            "tensorboard.main",
            "serve",
            "--logdir",
            str(self.log_directory),
            "--bind_all",
            "--port",
            str(port),
            "--reload_interval",
            "15",
        ]

        proc: subprocess.Popen | None = None
        if not self.start_in_subprocess:
            logger.warning("Tensorboard subprocess disabled because start_in_subprocess=False")

        else:
            proc = subprocess.Popen(  # pylint: disable=consider-using-with
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
            )
            self.proc = proc

            # Gets the output line that shows the running address.
            assert proc.stdout is not None
            lines = []
            for line in proc.stdout:
                line_str = line.decode("utf-8", errors="replace")
                if line_str.startswith("TensorBoard"):
                    line_str = parse_url(make_localhost(line_str))
                    logging.log(LOG_STATUS, line_str)
                    break
                lines.append(line_str)
            else:
                line_str = "".join(lines)
                raise RuntimeError(f"Tensorboard failed to start:\n{line_str}")

        atexit.register(self.cleanup)

        # The server keeps writing to its stdout/stderr pipe. If nothing reads
        # it, the OS pipe buffer eventually fills and the server blocks on
        # write. This thread is a daemon, so it can read until the process
        # exits (e.g. when ``cleanup`` terminates it and the pipe hits EOF).
        if proc is not None and proc.stdout is not None:
            for _ in proc.stdout:
                pass

    def cleanup(self) -> None:
        """Terminate and reap the TensorBoard subprocess, if one is running."""
        # ``proc`` may not exist if __init__ failed before assigning it.
        proc = getattr(self, "proc", None)
        if proc is not None:
            proc.terminate()
            proc.wait()
            self.proc = None

    def __del__(self) -> None:
        self.cleanup()

    def clear_logs(self) -> None:
        """Delete the log directory once per instance, if it contains any subdirectory.

        Later calls are no-ops. If the deletion raises, the guard is not set,
        so a later call retries.
        """
        if self._logs_cleared:
            return
        if self.log_directory.exists() and any(child.is_dir() for child in self.log_directory.iterdir()):
            logger.warning("Clearing TensorBoard logs")
            shutil.rmtree(self.log_directory)
        self._logs_cleared = True

    def get_writer(self, phase: Phase) -> TensorboardWriter:
        """Return the writer for ``phase``, starting the server thread on first use."""
        self._start()
        return self.writers.writer(phase)

    def log_git_state(self, git_state: str) -> None:
        """Store the git state (as a fenced code block) for the next ``write``; master only."""
        if not is_master():
            return
        self.git_state = f"```\n{git_state}\n```"

    def log_training_code(self, training_code: str) -> None:
        """Store the training code (as a fenced Python block) for the next ``write``; master only."""
        if not is_master():
            return
        self.training_code = f"```python\n{training_code}\n```"

    def log_config(self, config: DictConfig) -> None:
        """Store the config for the next ``write``; master only."""
        if not is_master():
            return
        self.config = config

    def write(self, line: LogLine) -> None:
        """Write scalars, strings and images from ``line`` to the phase's writer.

        Only runs on the master process. At step 0 the previous logs are
        cleared first. Pending config / git state / training code are written
        once and then discarded.
        """
        if not is_master():
            return

        if line.state.num_steps == 0:
            self.clear_logs()

        writer = self.get_writer(line.state.phase)
        walltime = line.state.start_time_s + line.state.elapsed_time_s

        for namespace, scalars in line.scalars.items():
            for scalar_key, scalar_value in scalars.items():
                writer.add_scalar(
                    f"{namespace}/{scalar_key}",
                    as_float(scalar_value),
                    global_step=line.state.num_steps,
                    walltime=walltime,
                )

        for namespace, strings in line.strings.items():
            for string_key, string_value in strings.items():
                writer.add_text(
                    f"{namespace}/{string_key}",
                    string_value,
                    global_step=line.state.num_steps,
                    walltime=walltime,
                )

        for namespace, images in line.images.items():
            for image_key, image_value in images.items():
                writer.add_image(
                    f"{namespace}/{image_key}",
                    image_value.image,
                    global_step=line.state.num_steps,
                    walltime=walltime,
                )

        if self.config is not None:
            writer.add_text("config", f"```\n{OmegaConf.to_yaml(self.config)}\n```")
            self.config = None

        if self.git_state is not None:
            writer.add_text("git", self.git_state)
            self.git_state = None

        if self.training_code is not None:
            writer.add_text("code", self.training_code)
            self.training_code = None
@@ -0,0 +1,12 @@
1
+ """Defines a single interface for all the mixins."""
2
+
3
+ from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
4
+ from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
5
+ from xax.task.mixins.cpu_stats import CPUStatsConfig, CPUStatsMixin
6
+ from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
7
+ from xax.task.mixins.gpu_stats import GPUStatsConfig, GPUStatsMixin
8
+ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
9
+ from xax.task.mixins.process import ProcessConfig, ProcessMixin
10
+ from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
11
+ from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
12
+ from xax.task.mixins.train import TrainConfig, TrainMixin