xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +256 -1
- xax/core/conf.py +193 -0
- xax/core/state.py +81 -0
- xax/nn/__init__.py +0 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +77 -0
- xax/nn/parallel.py +211 -0
- xax/requirements-dev.txt +15 -0
- xax/requirements.txt +23 -0
- xax/task/__init__.py +0 -0
- xax/task/base.py +207 -0
- xax/task/launchers/__init__.py +0 -0
- xax/task/launchers/base.py +28 -0
- xax/task/launchers/cli.py +42 -0
- xax/task/launchers/single_process.py +30 -0
- xax/task/launchers/staged.py +29 -0
- xax/task/logger.py +783 -0
- xax/task/loggers/__init__.py +0 -0
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/json.py +121 -0
- xax/task/loggers/state.py +45 -0
- xax/task/loggers/stdout.py +170 -0
- xax/task/loggers/tensorboard.py +223 -0
- xax/task/mixins/__init__.py +12 -0
- xax/task/mixins/artifacts.py +114 -0
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +251 -0
- xax/task/mixins/data_loader.py +149 -0
- xax/task/mixins/gpu_stats.py +257 -0
- xax/task/mixins/logger.py +66 -0
- xax/task/mixins/process.py +51 -0
- xax/task/mixins/runnable.py +63 -0
- xax/task/mixins/step_wrapper.py +63 -0
- xax/task/mixins/train.py +541 -0
- xax/task/script.py +53 -0
- xax/task/task.py +65 -0
- xax/utils/__init__.py +0 -0
- xax/utils/data/__init__.py +0 -0
- xax/utils/data/collate.py +206 -0
- xax/utils/experiments.py +802 -0
- xax/utils/jax.py +14 -0
- xax/utils/logging.py +223 -0
- xax/utils/numpy.py +47 -0
- xax/utils/tensorboard.py +258 -0
- xax/utils/text.py +350 -0
- xax-0.0.5.dist-info/METADATA +40 -0
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.5.dist-info/top_level.txt +1 -0
- examples/mnist.py +0 -148
- xax-0.0.1.dist-info/METADATA +0 -21
- xax-0.0.1.dist-info/RECORD +0 -9
- xax-0.0.1.dist-info/top_level.txt +0 -2
- {examples → xax/core}/__init__.py +0 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/utils/jax.py
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
"""Defines some utility functions for interfacing with Jax."""
|
2
|
+
|
3
|
+
import jax.numpy as jnp
|
4
|
+
import numpy as np
|
5
|
+
|
6
|
+
# Type alias covering Python numeric scalars as well as NumPy / JAX arrays.
Number = int | float | np.ndarray | jnp.ndarray
|
7
|
+
|
8
|
+
|
9
|
+
def as_float(value: int | float | np.ndarray | jnp.ndarray) -> float:
|
10
|
+
if isinstance(value, (int, float)):
|
11
|
+
return float(value)
|
12
|
+
if isinstance(value, (np.ndarray, jnp.ndarray)):
|
13
|
+
return float(value.item())
|
14
|
+
raise TypeError(f"Unexpected type: {type(value)}")
|
xax/utils/logging.py
ADDED
@@ -0,0 +1,223 @@
|
|
1
|
+
"""Logging utilities."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import math
|
5
|
+
import socket
|
6
|
+
import sys
|
7
|
+
|
8
|
+
from omegaconf import OmegaConf
|
9
|
+
|
10
|
+
from xax.core.conf import load_user_config
|
11
|
+
from xax.utils.text import Color, color_parts, colored
|
12
|
+
|
13
|
+
# Custom log levels. Each is a small offset from a standard level, so it sorts
# just above that level (all of them stay below logging.WARNING == 30).

# Logging level to show on all ranks: RankFilter lets records at these levels
# through on every rank, not only rank 0.
LOG_INFO_ALL: int = logging.INFO + 1
LOG_DEBUG_ALL: int = logging.DEBUG + 1

# Show as a transient message.
LOG_PING: int = logging.INFO + 2

# Show as a persistent status message.
LOG_STATUS: int = logging.INFO + 3

# Reserved for error summary. RankFilter drops records at this level on every
# rank, so they never reach its handler.
LOG_ERROR_SUMMARY: int = logging.INFO + 4
|
25
|
+
|
26
|
+
|
27
|
+
class RankFilter(logging.Filter):
    """Logging filter which hides most records on non-zero ranks.

    * Records at ``LOG_ERROR_SUMMARY`` are dropped on every rank.
    * When ``rank`` is ``None`` or 0, every other record passes.
    * On any other rank, only records at ``LOG_DEBUG_ALL``, ``LOG_INFO_ALL``,
      ``LOG_STATUS``, ``WARNING``, ``ERROR`` or ``CRITICAL`` pass.
    """

    def __init__(self, *, rank: int | None = None) -> None:
        """Logging filter which filters out INFO logs on non-zero ranks.

        Args:
            rank: The current rank, or ``None`` if not running with multiple
                ranks.
        """
        super().__init__()

        self.rank = rank

        # Register display names for the custom levels so they show up by
        # name (e.g. "INFOALL") in ``record.levelname``.
        logging.addLevelName(LOG_INFO_ALL, "INFOALL")
        logging.addLevelName(LOG_DEBUG_ALL, "DEBUGALL")
        logging.addLevelName(LOG_PING, "PING")
        logging.addLevelName(LOG_STATUS, "STATUS")
        logging.addLevelName(LOG_ERROR_SUMMARY, "ERROR_SUMMARY")

        all_rank_levels = (
            LOG_DEBUG_ALL,
            LOG_INFO_ALL,
            LOG_STATUS,
            logging.CRITICAL,
            logging.ERROR,
            logging.WARNING,
        )
        no_rank_levels = (LOG_ERROR_SUMMARY,)

        # Level numbers used by ``filter``. Matching on ``record.levelno``
        # rather than ``record.levelname`` matters: a formatter may rewrite
        # ``levelname`` in place (e.g. to add color codes), and the logging
        # module hands that same record object to the next handler's filter.
        self.log_all_ranks_levelno: frozenset[int] = frozenset(all_rank_levels)
        self.log_no_ranks_levelno: frozenset[int] = frozenset(no_rank_levels)

        # The same sets expressed as level names, kept for code that inspects them.
        self.log_all_ranks = {logging.getLevelName(level) for level in all_rank_levels}
        self.log_no_ranks = {logging.getLevelName(level) for level in no_rank_levels}

    def filter(self, record: logging.LogRecord) -> bool:
        """Returns whether ``record`` should be emitted on this rank."""
        if record.levelno in self.log_no_ranks_levelno:
            return False
        if self.rank is None or self.rank == 0:
            return True
        return record.levelno in self.log_all_ranks_levelno
|
67
|
+
|
68
|
+
|
69
|
+
class ColoredFormatter(logging.Formatter):
|
70
|
+
"""Defines a custom formatter for displaying logs."""
|
71
|
+
|
72
|
+
RESET_SEQ = "\033[0m"
|
73
|
+
COLOR_SEQ = "\033[1;%dm"
|
74
|
+
BOLD_SEQ = "\033[1m"
|
75
|
+
|
76
|
+
COLORS: dict[str, Color] = {
|
77
|
+
"WARNING": "yellow",
|
78
|
+
"INFOALL": "magenta",
|
79
|
+
"INFO": "cyan",
|
80
|
+
"DEBUGALL": "grey",
|
81
|
+
"DEBUG": "grey",
|
82
|
+
"CRITICAL": "yellow",
|
83
|
+
"FATAL": "red",
|
84
|
+
"ERROR": "red",
|
85
|
+
"STATUS": "green",
|
86
|
+
"PING": "magenta",
|
87
|
+
}
|
88
|
+
|
89
|
+
def __init__(
|
90
|
+
self,
|
91
|
+
*,
|
92
|
+
prefix: str | None = None,
|
93
|
+
rank: int | None = None,
|
94
|
+
world_size: int | None = None,
|
95
|
+
use_color: bool = True,
|
96
|
+
) -> None:
|
97
|
+
asc_start, asc_end = color_parts("grey")
|
98
|
+
name_start, name_end = color_parts("blue", bold=True)
|
99
|
+
|
100
|
+
message_pre = [
|
101
|
+
"{levelname:^19s}",
|
102
|
+
asc_start,
|
103
|
+
"{asctime}",
|
104
|
+
asc_end,
|
105
|
+
" [",
|
106
|
+
name_start,
|
107
|
+
"{name}",
|
108
|
+
name_end,
|
109
|
+
"]",
|
110
|
+
]
|
111
|
+
message_post = [" {message}"]
|
112
|
+
|
113
|
+
if prefix is not None:
|
114
|
+
message_pre += [" ", colored(prefix, "magenta", bold=True)]
|
115
|
+
|
116
|
+
if rank is not None or world_size is not None:
|
117
|
+
assert rank is not None and world_size is not None
|
118
|
+
digits = int(math.log10(world_size) + 1)
|
119
|
+
message_pre += [f" [{rank:0{digits}d}/{world_size}]"]
|
120
|
+
message = "".join(message_pre + message_post)
|
121
|
+
|
122
|
+
super().__init__(message, style="{", datefmt="%Y-%m-%d %H:%M:%S")
|
123
|
+
|
124
|
+
self.rank = rank
|
125
|
+
self.use_color = use_color
|
126
|
+
|
127
|
+
def format(self, record: logging.LogRecord) -> str:
|
128
|
+
levelname = record.levelname
|
129
|
+
|
130
|
+
match levelname:
|
131
|
+
case "DEBUG":
|
132
|
+
record.levelname = ""
|
133
|
+
case "INFOALL":
|
134
|
+
record.levelname = "INFO"
|
135
|
+
case "DEBUGALL":
|
136
|
+
record.levelname = "DEBUG"
|
137
|
+
|
138
|
+
if record.levelname and self.use_color and levelname in self.COLORS:
|
139
|
+
record.levelname = colored(record.levelname, self.COLORS[levelname], bold=True)
|
140
|
+
return logging.Formatter.format(self, record)
|
141
|
+
|
142
|
+
|
143
|
+
def configure_logging(prefix: str | None = None, *, rank: int | None = None, world_size: int | None = None) -> None:
    """Instantiates logging.

    Attaches a stdout handler (formatted by ``ColoredFormatter`` and filtered
    by ``RankFilter``) to the root logger, and routes ``warnings`` module
    output through logging. It also sets the root log level from the user
    config and optionally quiets some noisy third-party loggers.

    Each call attaches a new handler to the root logger, so calling this more
    than once in a process prints every record once per call.

    Args:
        prefix: An optional prefix to add to the logger
        rank: The current rank, or None if not using multiprocessing
        world_size: The total world size, or None if not using multiprocessing

    Raises:
        ValueError: If the configured log level is a name that is not a
            registered logging level.
    """
    if rank is not None or world_size is not None:
        assert rank is not None and world_size is not None
    root_logger = logging.getLogger()

    config = load_user_config().logging

    # Captures warnings from the warnings module.
    logging.captureWarnings(True)

    # Created before the level is resolved below: constructing a RankFilter
    # registers the custom level names (e.g. "INFOALL"), which makes them
    # valid values for `log_level` too.
    rank_filter = RankFilter(rank=rank)

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(ColoredFormatter(prefix=prefix, rank=rank, world_size=world_size))
    stream_handler.addFilter(rank_filter)
    root_logger.addHandler(stream_handler)

    # `Logger.setLevel` accepts a registered level name or a level number.
    # Names are upper-cased so that e.g. "info" works as well as "INFO".
    log_level = config.log_level
    if isinstance(log_level, str):
        log_level = log_level.upper()
    root_logger.setLevel(log_level)

    # Avoid junk logs from other libraries.
    if config.hide_third_party_logs:
        logging.getLogger("matplotlib").setLevel(logging.WARNING)
        logging.getLogger("PIL").setLevel(logging.WARNING)
        logging.getLogger("torch").setLevel(logging.WARNING)
|
178
|
+
|
179
|
+
|
180
|
+
def get_unused_port(default: int | None = None) -> int:
|
181
|
+
"""Returns an unused port number on the local machine.
|
182
|
+
|
183
|
+
Args:
|
184
|
+
default: A default port to try before trying other ports.
|
185
|
+
|
186
|
+
Returns:
|
187
|
+
A port number which is currently unused
|
188
|
+
"""
|
189
|
+
if default is not None:
|
190
|
+
sock = socket.socket()
|
191
|
+
try:
|
192
|
+
sock.bind(("", default))
|
193
|
+
return default
|
194
|
+
except OSError:
|
195
|
+
pass
|
196
|
+
finally:
|
197
|
+
sock.close()
|
198
|
+
|
199
|
+
sock = socket.socket()
|
200
|
+
sock.bind(("", 0))
|
201
|
+
return sock.getsockname()[1]
|
202
|
+
|
203
|
+
|
204
|
+
# Exposes `get_unused_port` to OmegaConf configs as the `mlfab.unused_port`
# resolver; `replace=True` keeps re-imports of this module from failing on a
# duplicate registration.
# NOTE(review): the `mlfab.` prefix looks like it was carried over from another
# package; existing configs may rely on it, so confirm before renaming.
OmegaConf.register_new_resolver(
    name="mlfab.unused_port",
    resolver=get_unused_port,
    replace=True,
)
|
205
|
+
|
206
|
+
|
207
|
+
def port_is_busy(port: int) -> bool:
    """Checks whether a port is busy.

    The check tries to bind a fresh socket to ``port`` on all interfaces. If
    the bind fails with an ``OSError``, the port is considered busy.

    Args:
        port: The port to check.

    Returns:
        True if the port could not be bound, False otherwise.
    """
    with socket.socket() as sock:
        try:
            sock.bind(("", port))
        except OSError:
            return True
        return False
|
xax/utils/numpy.py
ADDED
@@ -0,0 +1,47 @@
|
|
1
|
+
"""Defines some Numpy utility functions."""
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
|
6
|
+
def partial_flatten(x: np.ndarray) -> np.ndarray:
    """Collapses every axis after the first into a single axis.

    Args:
        x: The array to flatten; it must have at least one dimension.

    Returns:
        A 2-D array whose first dimension is ``x.shape[0]`` and whose second
        dimension holds all remaining elements of each leading entry.
    """
    leading = x.shape[0]
    target_shape = (leading, -1)
    return np.reshape(x, target_shape)
|
16
|
+
|
17
|
+
|
18
|
+
def one_hot(x: np.ndarray, k: int, dtype: type = np.float32) -> np.ndarray:
    """Converts an array of labels to a one-hot representation.

    Args:
        x: The array of integer labels, of any shape (array-likes such as
            lists are accepted too).
        k: The number of classes.
        dtype: The dtype of the returned array.

    Returns:
        An array of shape ``x.shape + (k,)`` that is 1 at each label's class
        index and 0 elsewhere. Labels outside ``[0, k)`` give all-zero rows.
    """
    labels = np.asarray(x)
    # Comparing against a new trailing class axis broadcasts for any input
    # rank; for 1-D labels this is identical to `labels[:, None]`.
    return np.array(labels[..., None] == np.arange(k), dtype)
|
30
|
+
|
31
|
+
|
32
|
+
def worker_chunk(x: np.ndarray, worker_id: int, num_workers: int, dim: int = 0) -> np.ndarray:
    """Chunks an array into `num_workers` chunks and returns one worker's chunk.

    The array is split along ``dim`` into ``num_workers`` contiguous chunks of
    ``x.shape[dim] // num_workers`` elements each. When that size is not
    divisible by ``num_workers``, the trailing remainder is not assigned to
    any worker.

    Args:
        x: The array to chunk.
        worker_id: The worker ID (0-based).
        num_workers: The number of workers.
        dim: The dimension to chunk along (negative values count from the end).

    Returns:
        The chunked array: ``x`` restricted along ``dim`` to this worker's chunk.
    """
    chunk_size = x.shape[dim] // num_workers
    start = worker_id * chunk_size
    end = start + chunk_size
    # Slice only along `dim`, keeping every other axis whole.
    index = [slice(None)] * x.ndim
    index[dim] = slice(start, end)
    return x[tuple(index)]
|
xax/utils/tensorboard.py
ADDED
@@ -0,0 +1,258 @@
|
|
1
|
+
"""Defines utility functions for interfacing with Tensorboard."""
|
2
|
+
|
3
|
+
import functools
|
4
|
+
import io
|
5
|
+
import time
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Literal, TypedDict
|
8
|
+
|
9
|
+
from PIL.Image import Image as PILImage
|
10
|
+
from tensorboard.compat.proto.config_pb2 import RunMetadata
|
11
|
+
from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
|
12
|
+
from tensorboard.compat.proto.graph_pb2 import GraphDef
|
13
|
+
from tensorboard.compat.proto.summary_pb2 import Summary, SummaryMetadata
|
14
|
+
from tensorboard.compat.proto.tensor_pb2 import TensorProto
|
15
|
+
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
|
16
|
+
from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
|
17
|
+
from tensorboard.summary.writer.event_file_writer import EventFileWriter
|
18
|
+
|
19
|
+
from xax.core.state import Phase
|
20
|
+
|
21
|
+
# Accepted image array layouts. Presumably N = batch, H = height, W = width,
# C = channel (conventional naming; confirm against callers).
ImageShape = Literal["HWC", "CHW", "HW", "NHWC", "NCHW", "NHW"]
|
22
|
+
|
23
|
+
|
24
|
+
class TensorboardProtobufWriter:
    """Writes raw protobuf ``Event`` messages to a Tensorboard event file.

    The underlying ``EventFileWriter`` is only constructed the first time
    ``event_writer`` is accessed (for example by the first ``add_event`` call).

    Parameters:
        log_directory: The directory to write the event file to.
        max_queue_size: Maximum size of the event writer's queue.
        flush_seconds: How often, in seconds, the event writer flushes.
        filename_suffix: Suffix for the event file name.
    """

    def __init__(
        self,
        log_directory: str | Path,
        max_queue_size: int = 10,
        flush_seconds: float = 120.0,
        filename_suffix: str = "",
    ) -> None:
        super().__init__()

        self.log_directory = Path(log_directory)
        self.max_queue_size = max_queue_size
        self.flush_seconds = flush_seconds
        self.filename_suffix = filename_suffix

    @functools.cached_property
    def event_writer(self) -> EventFileWriter:
        """The lazily-created event file writer for ``log_directory``."""
        return EventFileWriter(
            logdir=str(self.log_directory),
            max_queue_size=self.max_queue_size,
            flush_secs=self.flush_seconds,
            filename_suffix=self.filename_suffix,
        )

    def _has_event_writer(self) -> bool:
        """Returns whether ``event_writer`` has already been created."""
        # `functools.cached_property` stores the computed value in the instance
        # `__dict__` under the property's name, so this does not create it.
        return "event_writer" in self.__dict__

    def add_event(
        self,
        event: Event,
        step: int | None = None,
        walltime: float | None = None,
    ) -> None:
        """Stamps ``event`` with a wall time (and optional step) and writes it.

        Args:
            event: The event to write; its ``wall_time``/``step`` are set in place.
            step: Optional global step to record on the event.
            walltime: Wall time to record; defaults to ``time.time()``.
        """
        event.wall_time = time.time() if walltime is None else walltime
        if step is not None:
            event.step = int(step)
        self.event_writer.add_event(event)

    def add_summary(
        self,
        summary: Summary,
        global_step: int | None = None,
        walltime: float | None = None,
    ) -> None:
        """Wraps ``summary`` in an ``Event`` and writes it."""
        event = Event(summary=summary)
        self.add_event(event, step=global_step, walltime=walltime)

    def add_graph(
        self,
        graph: GraphDef,
        run_metadata: RunMetadata | None = None,
        walltime: float | None = None,
    ) -> None:
        """Writes a serialized graph definition and, optionally, its run metadata.

        The run metadata, if given, is written as a separate event tagged ``"step1"``.
        """
        event = Event(graph_def=graph.SerializeToString())
        self.add_event(event, walltime=walltime)
        if run_metadata is not None:
            trm = TaggedRunMetadata(tag="step1", run_metadata=run_metadata.SerializeToString())
            event = Event(tagged_run_metadata=trm)
            self.add_event(event, walltime=walltime)

    def flush(self) -> None:
        """Flushes pending events.

        Does nothing if ``event_writer`` has not been created yet. Flushing an
        unused writer therefore does not instantiate an ``EventFileWriter``,
        which could create files in ``log_directory``.
        """
        if self._has_event_writer():
            self.event_writer.flush()

    def close(self) -> None:
        """Closes the event writer, if it was ever created."""
        if self._has_event_writer():
            self.event_writer.close()
|
86
|
+
|
87
|
+
|
88
|
+
class TensorboardWriter:
    """Defines a class for writing artifacts to Tensorboard.

    Parameters:
        log_directory: The directory to write logs to.
        max_queue_size: The maximum queue size.
        flush_seconds: How often to flush logs.
        filename_suffix: The filename suffix to use.
    """

    def __init__(
        self,
        log_directory: str | Path,
        max_queue_size: int = 10,
        flush_seconds: float = 120.0,
        filename_suffix: str = "",
    ) -> None:
        super().__init__()

        self.pb_writer = TensorboardProtobufWriter(
            log_directory=log_directory,
            max_queue_size=max_queue_size,
            flush_seconds=flush_seconds,
            filename_suffix=filename_suffix,
        )

    def add_scalar(
        self,
        tag: str,
        value: float,
        global_step: int | None = None,
        walltime: float | None = None,
        new_style: bool = True,
        double_precision: bool = False,
    ) -> None:
        """Logs a single scalar value.

        Args:
            tag: The tag to log the value under.
            value: The scalar to log. Anything ``float()`` accepts (ints,
                NumPy scalars, ...) is converted to a Python float first.
            global_step: Optional global step for the event.
            walltime: Optional wall time for the event.
            new_style: If True, store the value as a tensor tagged for the
                "scalars" plugin; otherwise use the legacy ``simple_value``.
            double_precision: With ``new_style``, store the tensor as
                ``DT_DOUBLE`` instead of ``DT_FLOAT``.
        """
        # Normalize to a built-in float so the protobuf fields always receive
        # a plain Python number, whatever numeric type the caller passed.
        value = float(value)

        if new_style:
            tensor = (
                TensorProto(double_val=[value], dtype="DT_DOUBLE")
                if double_precision
                else TensorProto(float_val=[value], dtype="DT_FLOAT")
            )
            summary_value = Summary.Value(
                tag=tag,
                tensor=tensor,
                metadata=SummaryMetadata(
                    plugin_data=SummaryMetadata.PluginData(
                        plugin_name="scalars",
                    ),
                ),
            )
        else:
            summary_value = Summary.Value(
                tag=tag,
                simple_value=value,
            )

        self.pb_writer.add_summary(
            Summary(value=[summary_value]),
            global_step=global_step,
            walltime=walltime,
        )

    def add_image(
        self,
        tag: str,
        value: PILImage,
        global_step: int | None = None,
        walltime: float | None = None,
    ) -> None:
        """Logs a PIL image, converted to RGB and PNG-encoded.

        Args:
            tag: The tag to log the image under.
            value: The image to log.
            global_step: Optional global step for the event.
            walltime: Optional wall time for the event.
        """
        with io.BytesIO() as output:
            value.convert("RGB").save(output, format="PNG")
            image_string = output.getvalue()

        self.pb_writer.add_summary(
            Summary(
                value=[
                    Summary.Value(
                        tag=tag,
                        image=Summary.Image(
                            height=value.height,
                            width=value.width,
                            colorspace=3,  # RGB
                            encoded_image_string=image_string,
                        ),
                    ),
                ],
            ),
            global_step=global_step,
            walltime=walltime,
        )

    def add_text(
        self,
        tag: str,
        value: str,
        global_step: int | None = None,
        walltime: float | None = None,
    ) -> None:
        """Logs a UTF-8 string for the text plugin under ``<tag>/text_summary``.

        Args:
            tag: The base tag; ``"/text_summary"`` is appended to it.
            value: The text to log.
            global_step: Optional global step for the event.
            walltime: Optional wall time for the event.
        """
        self.pb_writer.add_summary(
            Summary(
                value=[
                    Summary.Value(
                        tag=tag + "/text_summary",
                        metadata=SummaryMetadata(
                            plugin_data=SummaryMetadata.PluginData(
                                plugin_name="text", content=TextPluginData(version=0).SerializeToString()
                            ),
                        ),
                        tensor=TensorProto(
                            dtype="DT_STRING",
                            string_val=[value.encode(encoding="utf_8")],
                            tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
                        ),
                    ),
                ],
            ),
            global_step=global_step,
            walltime=walltime,
        )
|
217
|
+
|
218
|
+
|
219
|
+
class TensorboardWriterKwargs(TypedDict):
    """Keyword arguments that ``TensorboardWriters`` forwards to each ``TensorboardWriter``."""

    # Maximum size of the event writer's queue.
    max_queue_size: int
    # How often, in seconds, logs are flushed.
    flush_seconds: float
    # Filename suffix for the event files.
    filename_suffix: str
|
223
|
+
|
224
|
+
|
225
|
+
class TensorboardWriters:
|
226
|
+
def __init__(
|
227
|
+
self,
|
228
|
+
log_directory: str | Path,
|
229
|
+
max_queue_size: int = 10,
|
230
|
+
flush_seconds: float = 120.0,
|
231
|
+
filename_suffix: str = "",
|
232
|
+
) -> None:
|
233
|
+
super().__init__()
|
234
|
+
|
235
|
+
self.log_directory = Path(log_directory)
|
236
|
+
|
237
|
+
self.kwargs: TensorboardWriterKwargs = {
|
238
|
+
"max_queue_size": max_queue_size,
|
239
|
+
"flush_seconds": flush_seconds,
|
240
|
+
"filename_suffix": filename_suffix,
|
241
|
+
}
|
242
|
+
|
243
|
+
@functools.cached_property
|
244
|
+
def train_writer(self) -> TensorboardWriter:
|
245
|
+
return TensorboardWriter(self.log_directory / "train", **self.kwargs)
|
246
|
+
|
247
|
+
@functools.cached_property
|
248
|
+
def valid_writer(self) -> TensorboardWriter:
|
249
|
+
return TensorboardWriter(self.log_directory / "valid", **self.kwargs)
|
250
|
+
|
251
|
+
def writer(self, phase: Phase) -> TensorboardWriter:
|
252
|
+
match phase:
|
253
|
+
case "train":
|
254
|
+
return self.train_writer
|
255
|
+
case "valid":
|
256
|
+
return self.valid_writer
|
257
|
+
case _:
|
258
|
+
raise NotImplementedError(f"Unexpected phase: {phase}")
|