xax 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +49 -7
- xax/core/conf.py +1 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +8 -4
- xax/requirements-dev.txt +9 -1
- xax/requirements.txt +15 -10
- xax/task/base.py +0 -6
- xax/task/logger.py +328 -393
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/tensorboard.py +2 -5
- xax/task/mixins/__init__.py +2 -1
- xax/task/mixins/artifacts.py +14 -7
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +10 -10
- xax/task/mixins/data_loader.py +6 -9
- xax/task/mixins/gpu_stats.py +3 -3
- xax/task/mixins/logger.py +2 -250
- xax/task/mixins/process.py +4 -0
- xax/task/mixins/train.py +71 -40
- xax/task/task.py +6 -5
- xax/utils/data/collate.py +6 -6
- xax/utils/experiments.py +45 -1
- xax/utils/logging.py +29 -0
- xax/utils/tensorboard.py +49 -29
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/METADATA +15 -14
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.3.dist-info/RECORD +0 -49
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/top_level.txt +0 -0
xax/task/loggers/callback.py
ADDED
@@ -0,0 +1,56 @@
+"""Defines a logger that calls a callback function with the log line."""
+
+from typing import Callable
+
+from omegaconf import DictConfig
+
+from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
+
+
+class CallbackLogger(LoggerImpl):
+    def __init__(
+        self,
+        *,
+        callback: Callable[[LogLine], None] = lambda x: None,
+        error_summary_callback: Callable[[LogErrorSummary], None] = lambda x: None,
+        error_callback: Callable[[LogError], None] = lambda x: None,
+        status_callback: Callable[[LogStatus], None] = lambda x: None,
+        ping_callback: Callable[[LogPing], None] = lambda x: None,
+        git_state_callback: Callable[[str], None] = lambda x: None,
+        training_code_callback: Callable[[str], None] = lambda x: None,
+        config_callback: Callable[[DictConfig], None] = lambda x: None,
+    ) -> None:
+        super().__init__()
+
+        self.callback = callback
+        self.error_summary_callback = error_summary_callback
+        self.error_callback = error_callback
+        self.status_callback = status_callback
+        self.ping_callback = ping_callback
+        self.git_state_callback = git_state_callback
+        self.training_code_callback = training_code_callback
+        self.config_callback = config_callback
+
+    def write(self, line: LogLine) -> None:
+        self.callback(line)
+
+    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
+        self.error_summary_callback(error_summary)
+
+    def write_error(self, error: LogError) -> None:
+        self.error_callback(error)
+
+    def write_status(self, status: LogStatus) -> None:
+        self.status_callback(status)
+
+    def write_ping(self, ping: LogPing) -> None:
+        self.ping_callback(ping)
+
+    def log_git_state(self, git_state: str) -> None:
+        self.git_state_callback(git_state)
+
+    def log_training_code(self, training_code: str) -> None:
+        self.training_code_callback(training_code)
+
+    def log_config(self, config: DictConfig) -> None:
+        self.config_callback(config)
xax/task/loggers/tensorboard.py
CHANGED
@@ -12,8 +12,6 @@ import time
 from pathlib import Path
 from typing import TypeVar
 
-import jax
-import PIL.Image
 from omegaconf import DictConfig, OmegaConf
 
 from xax.core.state import Phase
@@ -84,7 +82,7 @@ class TensorboardLogger(LoggerImpl):
         port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))
 
         while port_is_busy(port):
-            logger.warning(
+            logger.warning("Port %s is busy, waiting...", port)
             time.sleep(10)
 
     def make_localhost(s: str) -> str:
@@ -205,10 +203,9 @@ class TensorboardLogger(LoggerImpl):
 
         for namespace, images in line.images.items():
             for image_key, image_value in images.items():
-                image = PIL.Image.fromarray(jax.device_get(image_value.pixels))
                 writer.add_image(
                     f"{namespace}/{image_key}",
-                    image,
+                    image_value.image,
                     global_step=line.state.num_steps,
                     walltime=walltime,
                 )
xax/task/mixins/__init__.py
CHANGED
@@ -1,6 +1,7 @@
 """Defines a single interface for all the mixins."""
 
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
 from xax.task.mixins.cpu_stats import CPUStatsConfig, CPUStatsMixin
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.gpu_stats import GPUStatsConfig, GPUStatsMixin
@@ -8,4 +9,4 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.process import ProcessConfig, ProcessMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
 from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
-from xax.task.mixins.train import
+from xax.task.mixins.train import TrainConfig, TrainMixin
xax/task/mixins/artifacts.py
CHANGED
@@ -47,6 +47,10 @@ class ArtifactsMixin(BaseTask[Config]):
         self._exp_dir = exp_dir
         return self
 
+    @property
+    def exp_dir(self) -> Path:
+        return self.get_exp_dir()
+
     def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
         if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
             if not exists_ok:
@@ -61,13 +65,16 @@ class ArtifactsMixin(BaseTask[Config]):
         elif not missing_ok:
             raise RuntimeError(f"Lock file not found at {lock_file}")
 
-
-
+    def get_exp_dir(self) -> Path:
+        if self._exp_dir is not None:
+            return self._exp_dir
+
         if self.config.exp_dir is not None:
             exp_dir = Path(self.config.exp_dir).expanduser().resolve()
             exp_dir.mkdir(parents=True, exist_ok=True)
-
-
+            self._exp_dir = exp_dir
+            logger.log(LOG_STATUS, self._exp_dir)
+            return self._exp_dir
 
         def get_exp_dir(run_id: int) -> Path:
             return self.run_dir / f"run_{run_id}"
@@ -81,9 +88,9 @@ class ArtifactsMixin(BaseTask[Config]):
         while (exp_dir := get_exp_dir(run_id)).is_dir() and has_lock_file(exp_dir):
             run_id += 1
         exp_dir.mkdir(exist_ok=True, parents=True)
-
-        logger.log(LOG_STATUS,
-        return
+        self._exp_dir = exp_dir.expanduser().resolve()
+        logger.log(LOG_STATUS, self._exp_dir)
+        return self._exp_dir
 
     @functools.lru_cache(maxsize=None)
     def stage_environment(self) -> Path | None:
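The `exp_dir` property added here is a thin wrapper around `get_exp_dir()`, which now caches its result on `self._exp_dir`, so the experiment directory is resolved only once. A rough sketch of the resulting behavior; the helper function below is hypothetical and only exists to illustrate the caching.

```python
# Hypothetical sketch: repeated accesses of exp_dir return the cached Path.
from pathlib import Path

from xax.task.mixins.artifacts import ArtifactsMixin


def resolve_once(task: ArtifactsMixin) -> Path:
    first = task.exp_dir   # first access resolves (or allocates) the directory
    second = task.exp_dir  # later accesses return the cached self._exp_dir
    assert first == second
    return first
```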
xax/task/mixins/checkpointing.py
ADDED
@@ -0,0 +1,209 @@
+"""Defines a mixin for handling model checkpointing."""
+
+import io
+import json
+import logging
+import tarfile
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any, Callable, Generic, Literal, TypeVar, cast, overload
+
+import cloudpickle
+import optax
+from jaxtyping import PyTree
+from omegaconf import DictConfig, OmegaConf
+
+from xax.core.conf import field
+from xax.core.state import State
+from xax.nn.parallel import is_master
+from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+
+logger = logging.getLogger(__name__)
+
+CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
+
+
+def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
+    """Defines the path to the checkpoint for a given state.
+
+    Args:
+        exp_dir: The experiment directory
+        state: The current trainer state
+
+    Returns:
+        The path to the checkpoint file.
+    """
+    if state is None:
+        return exp_dir / "checkpoints" / "ckpt.bin"
+    return exp_dir / "checkpoints" / f"ckpt.{state.num_steps}.bin"
+
+
+@dataclass
+class CheckpointingConfig(ArtifactsConfig):
+    save_every_n_steps: int | None = field(None, help="Save a checkpoint every N steps")
+    save_every_n_seconds: float | None = field(60.0 * 60.0, help="Save a checkpoint every N seconds")
+    only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
+    load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
+    load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
+
+
+Config = TypeVar("Config", bound=CheckpointingConfig)
+
+
+class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
+    def __init__(self, config: Config) -> None:
+        super().__init__(config)
+
+        self.__last_ckpt_time = 0.0
+
+    def get_ckpt_path(self, state: State | None = None) -> Path:
+        return get_ckpt_path(self.exp_dir, state)
+
+    def get_init_ckpt_path(self) -> Path | None:
+        if self._exp_dir is not None:
+            ckpt_path = self.get_ckpt_path()
+            if ckpt_path.exists():
+                return ckpt_path
+        if self.config.load_from_ckpt_path is not None:
+            ckpt_path = Path(self.config.load_from_ckpt_path)
+            assert ckpt_path.exists(), f"Checkpoint path {ckpt_path} does not exist."
+            return ckpt_path
+        return None
+
+    def should_checkpoint(self, state: State) -> bool:
+        if self.config.save_every_n_steps is not None:
+            if state.num_steps % self.config.save_every_n_steps == 0:
+                return True
+        if self.config.save_every_n_seconds is not None:
+            last_time, cur_time = self.__last_ckpt_time, state.elapsed_time_s
+            if cur_time - last_time >= self.config.save_every_n_seconds:
+                self.__last_ckpt_time = cur_time
+                return True
+        return False
+
+    @overload
+    def load_checkpoint(
+        self,
+        path: Path,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["opt"]) -> optax.GradientTransformation: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["opt_state"]) -> optax.OptState: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["state"]) -> State: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["config"]) -> DictConfig: ...
+
+    def load_checkpoint(
+        self,
+        path: Path,
+        part: CheckpointPart | None = None,
+    ) -> (
+        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+        | PyTree
+        | optax.GradientTransformation
+        | optax.OptState
+        | State
+        | DictConfig
+    ):
+        with tarfile.open(path, "r:gz") as tar:
+
+            def get_model() -> PyTree:
+                if (model := tar.extractfile("model")) is None:
+                    raise ValueError(f"Checkpoint does not contain a model file: {path}")
+                return cloudpickle.load(model)
+
+            def get_opt() -> optax.GradientTransformation:
+                if (opt := tar.extractfile("opt")) is None:
+                    raise ValueError(f"Checkpoint does not contain an opt file: {path}")
+                return cloudpickle.load(opt)
+
+            def get_opt_state() -> optax.OptState:
+                if (opt_state := tar.extractfile("opt_state")) is None:
+                    raise ValueError(f"Checkpoint does not contain an opt_state file: {path}")
+                return cloudpickle.load(opt_state)
+
+            def get_state() -> State:
+                if (state := tar.extractfile("state")) is None:
+                    raise ValueError(f"Checkpoint does not contain a state file: {path}")
+                return State(**json.loads(state.read().decode()))
+
+            def get_config() -> DictConfig:
+                if (config := tar.extractfile("config")) is None:
+                    raise ValueError(f"Checkpoint does not contain a config file: {path}")
+                return cast(DictConfig, OmegaConf.load(config))
+
+            match part:
+                case "model":
+                    return get_model()
+                case "opt":
+                    return get_opt()
+                case "opt_state":
+                    return get_opt_state()
+                case "state":
+                    return get_state()
+                case "config":
+                    return get_config()
+                case None:
+                    return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
+                case _:
+                    raise ValueError(f"Invalid checkpoint part: {part}")
+
+    def save_checkpoint(
+        self,
+        model: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+        state: State,
+    ) -> Path:
+        ckpt_path = self.get_ckpt_path(state)
+
+        if not is_master():
+            return ckpt_path
+
+        # Gets the path to the last checkpoint.
+        logger.info("Saving checkpoint to %s", ckpt_path)
+        last_ckpt_path = self.get_ckpt_path()
+        ckpt_path.parent.mkdir(exist_ok=True, parents=True)
+
+        # Potentially removes the last checkpoint.
+        if last_ckpt_path.exists() and self.config.only_save_most_recent:
+            if (base_ckpt := last_ckpt_path.resolve()).is_file():
+                base_ckpt.unlink()
+
+        # Combines all temporary files into a single checkpoint TAR file.
+        with tarfile.open(ckpt_path, "w:gz") as tar:
+
+            def add_file(name: str, write_fn: Callable[[io.BytesIO], Any]) -> None:
+                with io.BytesIO() as buf:
+                    write_fn(buf)
+                    tarinfo = tarfile.TarInfo(name)
+                    tarinfo.size = buf.tell()
+                    buf.seek(0)
+                    tar.addfile(tarinfo, buf)
+
+            add_file("model", lambda buf: cloudpickle.dump(model, buf))
+            add_file("opt", lambda buf: cloudpickle.dump(optimizer, buf))
+            add_file("opt_state", lambda buf: cloudpickle.dump(opt_state, buf))
+            add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
+            add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
+
+        # Updates the symlink to the new checkpoint.
+        last_ckpt_path.unlink(missing_ok=True)
+        try:
+            last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
+        except FileExistsError:
+            logger.exception("Exception while trying to update %s", ckpt_path)
+
+        # Marks directory as having artifacts which shouldn't be overwritten.
+        self.add_lock_file("ckpt", exists_ok=True)
+
+        return ckpt_path
xax/task/mixins/cpu_stats.py
CHANGED
@@ -237,15 +237,15 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
         stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
 
         if stats is not None:
-            self.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
-            self.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
-            self.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
-            self.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
-            self.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
-            self.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
-            self.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
-            self.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
-            self.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
-            self.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
+            self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
+            self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
+            self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
+            self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
+            self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
+            self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
+            self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
+            self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
+            self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
+            self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
 
         return state
xax/task/mixins/data_loader.py
CHANGED
@@ -38,7 +38,6 @@ class DataloaderErrorConfig:
 
 @dataclass
 class DataloaderConfig:
-    batch_size: int = field(MISSING, help="Size of each batch")
     num_workers: int | None = field(MISSING, help="Number of workers for loading samples")
     prefetch_factor: int = field(2, help="Number of items to pre-fetch on each worker")
     error: DataloaderErrorConfig = field(DataloaderErrorConfig(), help="Dataloader error configuration")
@@ -49,11 +48,11 @@ class DataloadersConfig(ProcessConfig, BaseConfig):
     batch_size: int = field(MISSING, help="Size of each batch")
     raise_dataloader_errors: bool = field(False, help="If set, raise dataloader errors inside the worker processes")
     train_dl: DataloaderConfig = field(
-        DataloaderConfig(
+        DataloaderConfig(num_workers=II("mlfab.num_workers:-1")),
         help="Train dataloader config",
     )
     valid_dl: DataloaderConfig = field(
-        DataloaderConfig(
+        DataloaderConfig(num_workers=1),
        help="Valid dataloader config",
     )
     debug_dataloader: bool = field(False, help="Debug dataloaders")
@@ -64,9 +63,7 @@ Config = TypeVar("Config", bound=DataloadersConfig)
 
 class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config], ABC):
     def __init__(self, config: Config) -> None:
-        if is_missing(config, "batch_size")
-            is_missing(config.train_dl, "batch_size") or is_missing(config.valid_dl, "batch_size")
-        ):
+        if is_missing(config, "batch_size"):
             config.batch_size = self.get_batch_size()
 
         super().__init__(config)
@@ -120,10 +117,10 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
 
         return Dataloader(
             dataset=dataset,
-            batch_size=
+            batch_size=self.config.batch_size,
             num_workers=0 if debugging else cfg.num_workers,
             prefetch_factor=cfg.prefetch_factor,
-
+            mp_manager=self.multiprocessing_manager,
             dataloader_worker_init_fn=self.dataloader_worker_init_fn,
             collate_worker_init_fn=self.collate_worker_init_fn,
             item_callback=self.dataloader_item_callback,
@@ -135,7 +132,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
 
     @classmethod
     def to_device_fn(cls, sample: T) -> T:
-        return recursive_apply(sample, jax.device_put)
+        return recursive_apply(sample, jax.device_put, include_numpy=True)
 
     @classmethod
     def dataloader_worker_init_fn(cls, worker_id: int, num_workers: int) -> None:
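With this change, `batch_size` is declared once on `DataloadersConfig` rather than on each `DataloaderConfig`, and the `Dataloader` construction reads `self.config.batch_size` for both splits. A small sketch of the new config shape; the field values are illustrative only, and building the config directly like this is an assumption for demonstration purposes.

```python
from xax.task.mixins.data_loader import DataloaderConfig, DataloadersConfig

# Hypothetical sketch: batch_size is a single top-level field, while
# per-dataloader settings such as num_workers stay under train_dl / valid_dl.
cfg = DataloadersConfig(
    batch_size=32,
    train_dl=DataloaderConfig(num_workers=4),
    valid_dl=DataloaderConfig(num_workers=1),
)
```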
xax/task/mixins/gpu_stats.py
CHANGED
@@ -250,8 +250,8 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
         for gpu_stat in stats.values():
             if gpu_stat is None:
                 continue
-            self.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
-            self.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
-            self.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
+            self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
+            self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
+            self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
 
         return state