xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +256 -1
- xax/core/conf.py +193 -0
- xax/core/state.py +81 -0
- xax/nn/__init__.py +0 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +77 -0
- xax/nn/parallel.py +211 -0
- xax/requirements-dev.txt +15 -0
- xax/requirements.txt +23 -0
- xax/task/__init__.py +0 -0
- xax/task/base.py +207 -0
- xax/task/launchers/__init__.py +0 -0
- xax/task/launchers/base.py +28 -0
- xax/task/launchers/cli.py +42 -0
- xax/task/launchers/single_process.py +30 -0
- xax/task/launchers/staged.py +29 -0
- xax/task/logger.py +783 -0
- xax/task/loggers/__init__.py +0 -0
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/json.py +121 -0
- xax/task/loggers/state.py +45 -0
- xax/task/loggers/stdout.py +170 -0
- xax/task/loggers/tensorboard.py +223 -0
- xax/task/mixins/__init__.py +12 -0
- xax/task/mixins/artifacts.py +114 -0
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +251 -0
- xax/task/mixins/data_loader.py +149 -0
- xax/task/mixins/gpu_stats.py +257 -0
- xax/task/mixins/logger.py +66 -0
- xax/task/mixins/process.py +51 -0
- xax/task/mixins/runnable.py +63 -0
- xax/task/mixins/step_wrapper.py +63 -0
- xax/task/mixins/train.py +541 -0
- xax/task/script.py +53 -0
- xax/task/task.py +65 -0
- xax/utils/__init__.py +0 -0
- xax/utils/data/__init__.py +0 -0
- xax/utils/data/collate.py +206 -0
- xax/utils/experiments.py +802 -0
- xax/utils/jax.py +14 -0
- xax/utils/logging.py +223 -0
- xax/utils/numpy.py +47 -0
- xax/utils/tensorboard.py +258 -0
- xax/utils/text.py +350 -0
- xax-0.0.5.dist-info/METADATA +40 -0
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.5.dist-info/top_level.txt +1 -0
- examples/mnist.py +0 -148
- xax-0.0.1.dist-info/METADATA +0 -21
- xax-0.0.1.dist-info/RECORD +0 -9
- xax-0.0.1.dist-info/top_level.txt +0 -2
- {examples → xax/core}/__init__.py +0 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
File without changes
|
@@ -0,0 +1,56 @@
|
|
1
|
+
"""Defines a logger that calls a callback function with the log line."""
|
2
|
+
|
3
|
+
from typing import Callable
|
4
|
+
|
5
|
+
from omegaconf import DictConfig
|
6
|
+
|
7
|
+
from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
|
8
|
+
|
9
|
+
|
10
|
+
class CallbackLogger(LoggerImpl):
    """Logger backend that forwards every log event to a caller-supplied callable.

    There is one callback per kind of event. Each callback defaults to a
    function that ignores its argument and returns ``None``, so callers only
    need to supply the hooks they are interested in.
    """

    def __init__(
        self,
        *,
        callback: Callable[[LogLine], None] = lambda x: None,
        error_summary_callback: Callable[[LogErrorSummary], None] = lambda x: None,
        error_callback: Callable[[LogError], None] = lambda x: None,
        status_callback: Callable[[LogStatus], None] = lambda x: None,
        ping_callback: Callable[[LogPing], None] = lambda x: None,
        git_state_callback: Callable[[str], None] = lambda x: None,
        training_code_callback: Callable[[str], None] = lambda x: None,
        config_callback: Callable[[DictConfig], None] = lambda x: None,
    ) -> None:
        """Stores the callbacks; all arguments are keyword-only.

        Args:
            callback: Called by ``write`` with each log line.
            error_summary_callback: Called by ``write_error_summary``.
            error_callback: Called by ``write_error``.
            status_callback: Called by ``write_status``.
            ping_callback: Called by ``write_ping``.
            git_state_callback: Called by ``log_git_state``.
            training_code_callback: Called by ``log_training_code``.
            config_callback: Called by ``log_config``.
        """
        super().__init__()

        # Hooks for the one-off run metadata.
        self.config_callback = config_callback
        self.training_code_callback = training_code_callback
        self.git_state_callback = git_state_callback

        # Hooks for the recurring log events.
        self.ping_callback = ping_callback
        self.status_callback = status_callback
        self.error_callback = error_callback
        self.error_summary_callback = error_summary_callback
        self.callback = callback

    def write(self, line: LogLine) -> None:
        """Passes the log line to ``callback``."""
        self.callback(line)

    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
        """Passes the error summary to ``error_summary_callback``."""
        self.error_summary_callback(error_summary)

    def write_error(self, error: LogError) -> None:
        """Passes the error to ``error_callback``."""
        self.error_callback(error)

    def write_status(self, status: LogStatus) -> None:
        """Passes the status to ``status_callback``."""
        self.status_callback(status)

    def write_ping(self, ping: LogPing) -> None:
        """Passes the ping to ``ping_callback``."""
        self.ping_callback(ping)

    def log_git_state(self, git_state: str) -> None:
        """Passes the git state string to ``git_state_callback``."""
        self.git_state_callback(git_state)

    def log_training_code(self, training_code: str) -> None:
        """Passes the training code string to ``training_code_callback``."""
        self.training_code_callback(training_code)

    def log_config(self, config: DictConfig) -> None:
        """Passes the config to ``config_callback``."""
        self.config_callback(config)
|
xax/task/loggers/json.py
ADDED
@@ -0,0 +1,121 @@
|
|
1
|
+
"""Defines a logger which logs JSON lines to a file."""
|
2
|
+
|
3
|
+
import json
|
4
|
+
import sys
|
5
|
+
from dataclasses import asdict
|
6
|
+
from typing import Any, Literal, TextIO
|
7
|
+
|
8
|
+
from jaxtyping import Array
|
9
|
+
|
10
|
+
from xax.task.logger import LogError, LoggerImpl, LogLine, LogPing, LogStatus
|
11
|
+
|
12
|
+
|
13
|
+
def get_json_value(value: Any) -> Any:  # noqa: ANN401
    """Unwraps array values so they can be placed in a JSON payload.

    Args:
        value: The value to convert.

    Returns:
        The result of ``value.item()`` when ``value`` is an ``Array``;
        otherwise ``value`` itself, unchanged.
    """
    return value.item() if isinstance(value, Array) else value
|
17
|
+
|
18
|
+
|
19
|
+
class JsonLogger(LoggerImpl):
    """Logger backend that writes each log event as one JSON document per line.

    Log lines, pings and statuses are JSON-encoded and written to
    ``log_stream``; errors are written as plain text to ``err_log_stream``.
    """

    def __init__(
        self,
        log_stream: TextIO = sys.stdout,
        err_log_stream: TextIO = sys.stderr,
        flush_immediately: bool = False,
        open_mode: Literal["w", "a"] = "w",
        line_sep: str = "\n",
        remove_unicode_from_namespaces: bool = True,
        log_interval_seconds: float = 10.0,
    ) -> None:
        """Configures the JSON-lines logger.

        Args:
            log_stream: The stream to log to.
            err_log_stream: The stream to log errors to.
            flush_immediately: Whether to flush the stream after every write.
            open_mode: The file open mode. Stored on the instance; none of
                the methods in this class read it.
            line_sep: The separator written after every record.
            remove_unicode_from_namespaces: Whether to drop non-ASCII
                characters from namespace names before using them as JSON
                keys. Namespaces may carry decorative unicode for other
                loggers, which is not wanted in the JSON output.
            log_interval_seconds: The interval between successive log lines.
        """
        super().__init__(log_interval_seconds)

        self.log_stream = log_stream
        self.err_log_stream = err_log_stream
        self.flush_immediately = flush_immediately
        self.open_mode = open_mode
        self.line_sep = line_sep
        self.remove_unicode_from_namespaces = remove_unicode_from_namespaces

    @property
    def fp(self) -> TextIO:
        """The stream that log lines, pings and statuses are written to."""
        return self.log_stream

    @property
    def err_fp(self) -> TextIO:
        """The stream that errors are written to."""
        return self.err_log_stream

    def get_json(self, line: LogLine) -> str:
        """Serializes a log line to a JSON string.

        The result has a ``"state"`` key holding the dataclass fields of
        ``line.state``, plus one key per namespace found in ``line.scalars``
        and ``line.strings``. Scalars are merged first, then strings, so a
        string overwrites a scalar with the same namespace and key.

        Args:
            line: The log line to serialize.

        Returns:
            The JSON-encoded payload.
        """
        payload: dict = {"state": asdict(line.state)}

        for group in (line.scalars, line.strings):
            for namespace, values in group.items():
                if self.remove_unicode_from_namespaces:
                    namespace = namespace.encode("ascii", errors="ignore").decode("ascii").strip()
                bucket = payload.setdefault(namespace, {})
                for key, value in values.items():
                    bucket[key] = get_json_value(value)

        return json.dumps(payload)

    def write(self, line: LogLine) -> None:
        """Writes the JSON form of ``line`` followed by the line separator."""
        stream = self.fp
        stream.write(self.get_json(line))
        stream.write(self.line_sep)
        if self.flush_immediately:
            stream.flush()

    def write_error(self, error: LogError) -> None:
        """Writes the error message, and its location if set, as plain text."""
        stream = self.err_fp
        stream.write(error.message)
        if error.location is not None:
            stream.write(f" ({error.location})")
        stream.write(self.line_sep)
        if self.flush_immediately:
            stream.flush()

    def write_ping(self, ping: LogPing) -> None:
        """Writes the ping as a JSON record under the ``"ping"`` key."""
        record = {
            "ping": {
                "message": ping.message,
                "filename": ping.filename,
                "lineno": ping.lineno,
            },
        }
        stream = self.fp
        stream.write(json.dumps(record))
        stream.write(self.line_sep)
        if self.flush_immediately:
            stream.flush()

    def write_status(self, status: LogStatus) -> None:
        """Writes the status as a JSON record under the ``"status"`` key."""
        record = {
            "status": {
                "message": status.message,
                "filename": status.filename,
                "lineno": status.lineno,
            },
        }
        stream = self.fp
        stream.write(json.dumps(record))
        stream.write(self.line_sep)
        if self.flush_immediately:
            stream.flush()
|
@@ -0,0 +1,45 @@
|
|
1
|
+
"""Defines a logger which logs the current training state."""
|
2
|
+
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Literal
|
5
|
+
|
6
|
+
from omegaconf import DictConfig, OmegaConf
|
7
|
+
|
8
|
+
from xax.task.logger import LoggerImpl, LogLine
|
9
|
+
|
10
|
+
|
11
|
+
class StateLogger(LoggerImpl):
    """Logger backend that saves run metadata to files in the run directory.

    It writes the git state, the training code and the config, each to its
    own file. Regular log lines are ignored.
    """

    def __init__(
        self,
        run_directory: str | Path,
        git_state_name: str = "git_state.txt",
        train_code_name: str = "train_code.py",
        config_name: str = "config.yaml",
        flush_immediately: bool = False,
        open_mode: Literal["w", "a"] = "w",
        line_sep: str = "\n",
        remove_unicode_from_namespaces: bool = True,
    ) -> None:
        """Resolves the output file paths.

        Args:
            run_directory: The directory to write the files to. ``~`` is
                expanded and the path is made absolute.
            git_state_name: File name for the git state.
            train_code_name: File name for the training code.
            config_name: File name for the config.
            flush_immediately: Stored on the instance; not read by any
                method in this class.
            open_mode: Stored on the instance; not read by any method in
                this class (files are always opened with mode ``"w"``).
            line_sep: Stored on the instance; not read by any method in
                this class.
            remove_unicode_from_namespaces: Stored on the instance; not read
                by any method in this class.
        """
        # The base class receives an infinite log interval; ``write`` is a
        # no-op for this logger anyway.
        super().__init__(float("inf"))

        # Resolve the run directory once instead of once per file.
        run_dir = Path(run_directory).expanduser().resolve()
        self.git_state_file = run_dir / git_state_name
        self.train_code_file = run_dir / train_code_name
        self.config_file = run_dir / config_name
        self.flush_immediately = flush_immediately
        self.open_mode = open_mode
        self.line_sep = line_sep
        self.remove_unicode_from_namespaces = remove_unicode_from_namespaces

    @staticmethod
    def _write_text(path: Path, text: str) -> None:
        """Overwrites ``path`` with ``text``, creating parent directories."""
        path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8 so non-ASCII content does not depend on the locale's
        # default encoding.
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)

    def log_git_state(self, git_state: str) -> None:
        """Writes the git state to ``git_state_file``, replacing old content."""
        self._write_text(self.git_state_file, git_state)

    def log_training_code(self, training_code: str) -> None:
        """Writes the training code to ``train_code_file``, replacing old content."""
        self._write_text(self.train_code_file, training_code)

    def log_config(self, config: DictConfig) -> None:
        """Saves the config to ``config_file`` using ``OmegaConf.save``."""
        self.config_file.parent.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(config, self.config_file)

    def write(self, line: LogLine) -> None:
        """Does nothing; this logger only records run metadata."""
        pass
|
@@ -0,0 +1,170 @@
|
|
1
|
+
"""Defines a logger that logs to stdout."""
|
2
|
+
|
3
|
+
import datetime
|
4
|
+
import logging
|
5
|
+
import sys
|
6
|
+
from collections import deque
|
7
|
+
from typing import Any, Deque, TextIO
|
8
|
+
|
9
|
+
from jaxtyping import Array
|
10
|
+
|
11
|
+
from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
|
12
|
+
from xax.utils.text import Color, colored, format_timedelta
|
13
|
+
|
14
|
+
|
15
|
+
def format_number(value: int | float, precision: int) -> str:
|
16
|
+
if isinstance(value, int):
|
17
|
+
return str(value)
|
18
|
+
return f"{value:.{precision}g}"
|
19
|
+
|
20
|
+
|
21
|
+
def as_str(value: Any, precision: int) -> str:  # noqa: ANN401
    """Renders a logged value as display text.

    Args:
        value: A string, an ``Array``, an ``int`` or a ``float``.
        precision: The precision passed on to ``format_number``.

    Returns:
        Strings wrapped in double quotes; numbers formatted by
        ``format_number``. Arrays are first unwrapped with ``.item()``.

    Raises:
        TypeError: If the (unwrapped) value is not a string or a number.
    """
    if isinstance(value, str):
        return f'"{value}"'

    scalar = value.item() if isinstance(value, Array) else value
    if not isinstance(scalar, (int, float)):
        raise TypeError(f"Unexpected log type: {type(scalar)}")
    return format_number(scalar, precision)
|
29
|
+
|
30
|
+
|
31
|
+
class StdoutLogger(LoggerImpl):
    """Logger backend that redraws a text dashboard on an output stream.

    Every ``write`` clears the terminal with ANSI escape codes and prints the
    training state, the logged values grouped by namespace, the queued
    status / ping / error messages and the most recent error summary.
    """

    def __init__(
        self,
        write_fp: TextIO = sys.stdout,
        precision: int = 4,
        log_timers: bool = False,
        log_perf: bool = False,
        log_optim: bool = False,
        log_fp: bool = False,
        log_interval_seconds: float = 1.0,
        remove_temporary_after: datetime.timedelta = datetime.timedelta(seconds=10),
    ) -> None:
        """Configures the stdout logger.

        Args:
            write_fp: The stream to write logs to.
            precision: The precision to use when formatting scalars.
            log_timers: Whether to show namespaces starting with "⏰".
            log_perf: Whether to show namespaces starting with "🔧".
            log_optim: Whether to show namespaces starting with "📉".
            log_fp: Whether to show namespaces starting with "⚖️".
            log_interval_seconds: The interval between successive log lines.
            remove_temporary_after: The age after which queued pings are
                dropped.
        """
        super().__init__(log_interval_seconds)

        self.write_fp = write_fp
        self.log_timers = log_timers
        self.log_perf = log_perf
        self.log_fp = log_fp
        self.log_optim = log_optim
        self.precision = precision
        self.remove_temporary_after = remove_temporary_after
        self.logger = logging.getLogger("stdout")

        # Each queue entry is (message, time it was received).
        self.statuses: Deque[tuple[str, datetime.datetime]] = deque()
        self.pings: Deque[tuple[str, datetime.datetime]] = deque()
        self.errors: Deque[tuple[str, datetime.datetime]] = deque()
        self.error_summary: tuple[str, datetime.datetime] | None = None

    def start(self) -> None:
        """Delegates to the base class."""
        return super().start()

    def stop(self) -> None:
        """Prints the queued messages one last time, then stops the base class."""
        self.write_queues()
        return super().stop()

    def write_separator(self) -> None:
        """Clears the screen and moves the cursor to the top-left corner."""
        self.write_fp.write("\033[2J\033[H")

    def write_state_window(self, line: LogLine) -> None:
        """Prints the phase, step count, sample count and elapsed time."""
        elapsed_time = format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True)
        state_info = {
            "Steps": f"{line.state.num_steps}",
            "Samples": f"{line.state.num_samples}",
            "Elapsed Time": f"{elapsed_time}",
        }

        colored_prefix = colored("Phase: ", "grey", bold=True)
        colored_phase = colored(line.state.phase, "green" if line.state.phase == "train" else "yellow", bold=True)
        self.write_fp.write(f"{colored_prefix}{colored_phase}\n")
        for k, v in state_info.items():
            self.write_fp.write(f" ↪ {k}: {colored(v, 'cyan')}\n")

    def write_log_window(self, line: LogLine) -> None:
        """Prints the scalars and strings of ``line``, grouped by namespace.

        Namespaces whose leading emoji marks a category that is disabled on
        this logger are skipped. Nothing is printed when no namespace
        remains.
        """
        namespace_to_lines: dict[str, dict[str, str]] = {}

        def add_logs(log: dict[str, dict[str, Any]], namespace_to_lines: dict[str, dict[str, str]]) -> None:
            for namespace, values in log.items():
                if not self.log_timers and namespace.startswith("⏰"):
                    continue
                if not self.log_perf and namespace.startswith("🔧"):
                    continue
                if not self.log_optim and namespace.startswith("📉"):
                    continue
                if not self.log_fp and namespace.startswith("⚖️"):
                    continue
                if namespace not in namespace_to_lines:
                    namespace_to_lines[namespace] = {}
                for k, v in values.items():
                    v_str = as_str(v, self.precision)
                    namespace_to_lines[namespace][k] = v_str

        add_logs(line.scalars, namespace_to_lines)
        add_logs(line.strings, namespace_to_lines)
        if not namespace_to_lines:
            return

        self.write_fp.write("\n")
        for namespace, lines in sorted(namespace_to_lines.items()):
            self.write_fp.write(f"{colored(namespace, 'cyan', bold=True)}\n")
            for k, v in lines.items():
                self.write_fp.write(f" ↪ {k}: {v}\n")

    def write_queue(self, title: str, q: Deque[tuple[str, datetime.datetime]], remove: bool, color: Color) -> None:
        """Prints a titled list of queued messages, newest first.

        Args:
            title: The heading printed above the messages.
            q: The queue of (message, timestamp) pairs; nothing is printed
                when it is empty.
            remove: Whether to drop entries older than
                ``remove_temporary_after`` after printing.
            color: The color used for the messages.
        """
        if not q:
            return

        self.write_fp.write(f"\n{colored(title, 'grey', bold=True)}\n")
        self.write_fp.write("\n".join(f" ✦ {colored(msg, color)}" for msg, _ in reversed(q)))
        self.write_fp.write("\n")

        if remove:
            # Entries are appended in arrival order, so the oldest is on the left.
            now = datetime.datetime.now()
            while q and now - q[0][1] > self.remove_temporary_after:
                q.popleft()

    def write_queues(self) -> None:
        """Prints statuses, pings and errors; only pings expire."""
        self.write_queue("Status", self.statuses, False, "green")
        self.write_queue("Pings", self.pings, True, "cyan")
        self.write_queue("Errors", self.errors, False, "red")

    def write_error_summary_to_screen(self) -> None:
        """Prints the stored error summary and its timestamp, if there is one."""
        if self.error_summary is not None:
            summary, timestamp = self.error_summary
            timestamp_string = timestamp.strftime("%Y-%m-%d %H:%M:%S")
            self.write_fp.write(f"\n{colored('Exception summary', 'grey', bold=True)}")
            self.write_fp.write(f" {colored(timestamp_string, 'grey')}")
            self.write_fp.write(f"\n{summary}")

    def write(self, line: LogLine) -> None:
        """Redraws the whole dashboard for ``line`` and flushes the stream."""
        self.write_separator()
        self.write_state_window(line)
        self.write_log_window(line)
        self.write_queues()
        self.write_error_summary_to_screen()
        # Flush the stream that was actually written to; it is not
        # necessarily the current ``sys.stdout``.
        self.write_fp.flush()

    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
        """Stores the summary message with the current time, replacing any previous one."""
        self.error_summary = error_summary.message, datetime.datetime.now()

    def write_error(self, error: LogError) -> None:
        """Queues the error message (with location) for the next redraw."""
        self.errors.append((error.message_with_location, datetime.datetime.now()))

    def write_status(self, status: LogStatus) -> None:
        """Queues the status message for the next redraw."""
        self.statuses.append((status.message, datetime.datetime.now()))

    def write_ping(self, ping: LogPing) -> None:
        """Queues the ping message for the next redraw."""
        self.pings.append((ping.message, datetime.datetime.now()))
|
@@ -0,0 +1,223 @@
|
|
1
|
+
"""Defines a Tensorboard logger backend."""
|
2
|
+
|
3
|
+
import atexit
|
4
|
+
import functools
|
5
|
+
import logging
|
6
|
+
import os
|
7
|
+
import re
|
8
|
+
import shutil
|
9
|
+
import subprocess
|
10
|
+
import threading
|
11
|
+
import time
|
12
|
+
from pathlib import Path
|
13
|
+
from typing import TypeVar
|
14
|
+
|
15
|
+
from omegaconf import DictConfig, OmegaConf
|
16
|
+
|
17
|
+
from xax.core.state import Phase
|
18
|
+
from xax.nn.parallel import is_master
|
19
|
+
from xax.task.logger import LoggerImpl, LogLine
|
20
|
+
from xax.utils.jax import as_float
|
21
|
+
from xax.utils.logging import LOG_STATUS, port_is_busy
|
22
|
+
from xax.utils.tensorboard import TensorboardWriter, TensorboardWriters
|
23
|
+
|
24
|
+
logger: logging.Logger = logging.getLogger(__name__)
|
25
|
+
|
26
|
+
T = TypeVar("T")
|
27
|
+
|
28
|
+
DEFAULT_TENSORBOARD_PORT = 9249
|
29
|
+
|
30
|
+
|
31
|
+
class TensorboardLogger(LoggerImpl):
    """Logger backend that writes scalars, strings and images to TensorBoard.

    On first use it can also launch a TensorBoard server in a subprocess
    from a background thread. All writing and the server launch are gated
    on ``is_master()`` (presumably true only on the primary process; confirm
    against ``xax.nn.parallel``).
    """

    def __init__(
        self,
        run_directory: str | Path,
        subdirectory: str = "tensorboard",
        flush_seconds: float = 10.0,
        wait_seconds: float = 0.0,
        start_in_subprocess: bool = True,
        use_localhost: bool = False,
        log_interval_seconds: float = 10.0,
    ) -> None:
        """Defines a logger which writes to Tensorboard.

        Args:
            run_directory: The root run directory.
            subdirectory: The subdirectory of the run directory to write
                Tensorboard logs to.
            flush_seconds: How often to flush logs.
            wait_seconds: Time to wait before starting Tensorboard process.
            start_in_subprocess: Start TensorBoard subprocess.
            use_localhost: Use localhost for TensorBoard address.
            log_interval_seconds: The interval between successive log lines.
        """
        super().__init__(log_interval_seconds)

        self.log_directory = Path(run_directory).expanduser().resolve() / subdirectory
        self.wait_seconds = wait_seconds
        self.start_in_subprocess = start_in_subprocess
        self.use_localhost = use_localhost

        self.proc: subprocess.Popen | None = None

        # Run metadata is held here until the next ``write`` emits it.
        self.git_state: str | None = None
        self.training_code: str | None = None
        self.config: DictConfig | None = None

        self.writers = TensorboardWriters(log_directory=self.log_directory, flush_seconds=flush_seconds)
        self._started = False
        # Per-instance "already cleared" flag. A cache decorator on the
        # method would keep a global reference to every instance.
        self._logs_cleared = False

    def _start(self) -> None:
        """Starts the background worker thread once per instance."""
        if self._started:
            return

        if is_master():
            threading.Thread(target=self.worker_thread, daemon=True).start()

        self._started = True

    def worker_thread(self) -> None:
        """Launches the TensorBoard server and logs its URL.

        Sleeps for ``wait_seconds``, waits until the port (environment
        variable ``TENSORBOARD_PORT``, or the default) is free, starts the
        server and reads its output until the line announcing the address.

        Raises:
            RuntimeError: If the subprocess output ends without that line.
        """
        import sys  # Local import: only needed to locate the interpreter.

        time.sleep(self.wait_seconds)

        if not self.start_in_subprocess:
            # Nothing will be launched, so do not poll the port.
            logger.warning("Tensorboard subprocess disabled because start_in_subprocess=False")
            return

        port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))

        while port_is_busy(port):
            logger.warning("Port %s is busy, waiting...", port)
            time.sleep(10)

        def make_localhost(s: str) -> str:
            # Replaces the host part of the URL with "localhost" if requested.
            if self.use_localhost:
                s = re.sub(rf"://(.+?):{port}", f"://localhost:{port}", s)
            return s

        def parse_url(s: str) -> str:
            # Reduces the announcement line to just the URL when one is found.
            m = re.search(r" (http\S+?) ", s)
            if m is None:
                return s
            return f"Tensorboard: {m.group(1)}"

        command: list[str] = [
            # Use the running interpreter so the server starts in the same
            # environment; fall back to PATH lookup if it is unknown.
            sys.executable or "python",
            "-m",
            "tensorboard.main",
            "serve",
            "--logdir",
            str(self.log_directory),
            "--bind_all",
            "--port",
            str(port),
            "--reload_interval",
            "15",
        ]

        self.proc = subprocess.Popen(  # pylint: disable=consider-using-with
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )

        # Register before parsing the output so the subprocess is terminated
        # at exit even if startup detection fails below.
        atexit.register(self.cleanup)

        # Gets the output line that shows the running address.
        assert self.proc is not None and self.proc.stdout is not None
        lines = []
        for line in self.proc.stdout:
            line_str = line.decode("utf-8")
            if line_str.startswith("TensorBoard"):
                line_str = parse_url(make_localhost(line_str))
                logging.log(LOG_STATUS, line_str)
                break
            lines.append(line_str)
        else:
            line_str = "".join(lines)
            raise RuntimeError(f"Tensorboard failed to start:\n{line_str}")

    def cleanup(self) -> None:
        """Terminates the TensorBoard subprocess, if one is running."""
        # ``getattr`` keeps this safe when called from ``__del__`` on an
        # instance whose ``__init__`` did not finish.
        proc = getattr(self, "proc", None)
        if proc is not None:
            proc.terminate()
            proc.wait()
            self.proc = None

    def __del__(self) -> None:
        self.cleanup()

    def clear_logs(self) -> None:
        """Deletes the log directory if it contains subdirectories.

        Only the first successful call on an instance does anything; later
        calls return immediately, to avoid clearing logs multiple times.
        """
        if self._logs_cleared:
            return
        if self.log_directory.exists() and any(child.is_dir() for child in self.log_directory.iterdir()):
            logger.warning("Clearing TensorBoard logs")
            shutil.rmtree(self.log_directory)
        # Set only after success so a failed removal is retried next time.
        self._logs_cleared = True

    def get_writer(self, phase: Phase) -> TensorboardWriter:
        """Returns the writer for ``phase``, starting the worker thread if needed."""
        self._start()
        return self.writers.writer(phase)

    def log_git_state(self, git_state: str) -> None:
        """Stores the git state, wrapped in a code fence, for the next ``write``."""
        if not is_master():
            return
        self.git_state = f"```\n{git_state}\n```"

    def log_training_code(self, training_code: str) -> None:
        """Stores the training code, wrapped in a Python code fence, for the next ``write``."""
        if not is_master():
            return
        self.training_code = f"```python\n{training_code}\n```"

    def log_config(self, config: DictConfig) -> None:
        """Stores the config for the next ``write``."""
        if not is_master():
            return
        self.config = config

    def write(self, line: LogLine) -> None:
        """Writes the contents of ``line`` to the writer for its phase.

        Scalars, strings and images are written under ``namespace/key`` at
        step ``num_steps``. Any pending config, git state or training code
        is written once and then discarded. Existing logs are cleared when
        the line is for step 0.
        """
        if not is_master():
            return

        if line.state.num_steps == 0:
            self.clear_logs()

        writer = self.get_writer(line.state.phase)
        walltime = line.state.start_time_s + line.state.elapsed_time_s

        for namespace, scalars in line.scalars.items():
            for scalar_key, scalar_value in scalars.items():
                writer.add_scalar(
                    f"{namespace}/{scalar_key}",
                    as_float(scalar_value),
                    global_step=line.state.num_steps,
                    walltime=walltime,
                )

        for namespace, strings in line.strings.items():
            for string_key, string_value in strings.items():
                writer.add_text(
                    f"{namespace}/{string_key}",
                    string_value,
                    global_step=line.state.num_steps,
                    walltime=walltime,
                )

        for namespace, images in line.images.items():
            for image_key, image_value in images.items():
                writer.add_image(
                    f"{namespace}/{image_key}",
                    image_value.image,
                    global_step=line.state.num_steps,
                    walltime=walltime,
                )

        if self.config is not None:
            writer.add_text("config", f"```\n{OmegaConf.to_yaml(self.config)}\n```")
            self.config = None

        if self.git_state is not None:
            writer.add_text("git", self.git_state)
            self.git_state = None

        if self.training_code is not None:
            writer.add_text("code", self.training_code)
            self.training_code = None
|
@@ -0,0 +1,12 @@
|
|
1
|
+
"""Defines a single interface for all the mixins."""
|
2
|
+
|
3
|
+
from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
|
4
|
+
from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
|
5
|
+
from xax.task.mixins.cpu_stats import CPUStatsConfig, CPUStatsMixin
|
6
|
+
from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
|
7
|
+
from xax.task.mixins.gpu_stats import GPUStatsConfig, GPUStatsMixin
|
8
|
+
from xax.task.mixins.logger import LoggerConfig, LoggerMixin
|
9
|
+
from xax.task.mixins.process import ProcessConfig, ProcessMixin
|
10
|
+
from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
|
11
|
+
from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
|
12
|
+
from xax.task.mixins.train import TrainConfig, TrainMixin
|