xax 0.0.3-py3-none-any.whl → 0.0.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,56 @@
+ """Defines a logger that calls a callback function with the log line."""
+
+ from typing import Callable
+
+ from omegaconf import DictConfig
+
+ from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
+
+
+ class CallbackLogger(LoggerImpl):
+     def __init__(
+         self,
+         *,
+         callback: Callable[[LogLine], None] = lambda x: None,
+         error_summary_callback: Callable[[LogErrorSummary], None] = lambda x: None,
+         error_callback: Callable[[LogError], None] = lambda x: None,
+         status_callback: Callable[[LogStatus], None] = lambda x: None,
+         ping_callback: Callable[[LogPing], None] = lambda x: None,
+         git_state_callback: Callable[[str], None] = lambda x: None,
+         training_code_callback: Callable[[str], None] = lambda x: None,
+         config_callback: Callable[[DictConfig], None] = lambda x: None,
+     ) -> None:
+         super().__init__()
+
+         self.callback = callback
+         self.error_summary_callback = error_summary_callback
+         self.error_callback = error_callback
+         self.status_callback = status_callback
+         self.ping_callback = ping_callback
+         self.git_state_callback = git_state_callback
+         self.training_code_callback = training_code_callback
+         self.config_callback = config_callback
+
+     def write(self, line: LogLine) -> None:
+         self.callback(line)
+
+     def write_error_summary(self, error_summary: LogErrorSummary) -> None:
+         self.error_summary_callback(error_summary)
+
+     def write_error(self, error: LogError) -> None:
+         self.error_callback(error)
+
+     def write_status(self, status: LogStatus) -> None:
+         self.status_callback(status)
+
+     def write_ping(self, ping: LogPing) -> None:
+         self.ping_callback(ping)
+
+     def log_git_state(self, git_state: str) -> None:
+         self.git_state_callback(git_state)
+
+     def log_training_code(self, training_code: str) -> None:
+         self.training_code_callback(training_code)
+
+     def log_config(self, config: DictConfig) -> None:
+         self.config_callback(config)
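The new CallbackLogger is a thin LoggerImpl that forwards every log artifact (lines, errors, status and ping messages, git state, training code, and the resolved config) to user-supplied functions, each defaulting to a no-op. A minimal usage sketch, assuming the file is importable as xax.task.loggers.callback (the import path and surrounding wiring are assumptions, not shown in this diff):

    from xax.task.loggers.callback import CallbackLogger  # assumed module path

    # Forward log lines and status messages to plain print; all other callbacks stay no-ops.
    cb_logger = CallbackLogger(
        callback=lambda line: print(line),
        status_callback=lambda status: print("status:", status),
    )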
@@ -12,8 +12,6 @@ import time
  from pathlib import Path
  from typing import TypeVar

- import jax
- import PIL.Image
  from omegaconf import DictConfig, OmegaConf

  from xax.core.state import Phase
@@ -84,7 +82,7 @@ class TensorboardLogger(LoggerImpl):
          port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))

          while port_is_busy(port):
-             logger.warning(f"Port {port} is busy, waiting...")
+             logger.warning("Port %s is busy, waiting...", port)
              time.sleep(10)

          def make_localhost(s: str) -> str:
@@ -205,10 +203,9 @@ class TensorboardLogger(LoggerImpl):

          for namespace, images in line.images.items():
              for image_key, image_value in images.items():
-                 image = PIL.Image.fromarray(jax.device_get(image_value.pixels))
                  writer.add_image(
                      f"{namespace}/{image_key}",
-                     image,
+                     image_value.image,
                      global_step=line.state.num_steps,
                      walltime=walltime,
                  )
@@ -1,6 +1,7 @@
  """Defines a single interface for all the mixins."""

  from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+ from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
  from xax.task.mixins.cpu_stats import CPUStatsConfig, CPUStatsMixin
  from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
  from xax.task.mixins.gpu_stats import GPUStatsConfig, GPUStatsMixin
@@ -8,4 +9,4 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
  from xax.task.mixins.process import ProcessConfig, ProcessMixin
  from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
  from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
- from xax.task.mixins.train import Batch, Model, Output, TrainConfig, TrainMixin
+ from xax.task.mixins.train import TrainConfig, TrainMixin
@@ -47,6 +47,10 @@ class ArtifactsMixin(BaseTask[Config]):
          self._exp_dir = exp_dir
          return self

+     @property
+     def exp_dir(self) -> Path:
+         return self.get_exp_dir()
+
      def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
          if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
              if not exists_ok:
@@ -61,13 +65,16 @@ class ArtifactsMixin(BaseTask[Config]):
          elif not missing_ok:
              raise RuntimeError(f"Lock file not found at {lock_file}")

-     @functools.cached_property
-     def exp_dir(self) -> Path:
+     def get_exp_dir(self) -> Path:
+         if self._exp_dir is not None:
+             return self._exp_dir
+
          if self.config.exp_dir is not None:
              exp_dir = Path(self.config.exp_dir).expanduser().resolve()
              exp_dir.mkdir(parents=True, exist_ok=True)
-             logger.log(LOG_STATUS, self.exp_dir)
-             return exp_dir
+             self._exp_dir = exp_dir
+             logger.log(LOG_STATUS, self._exp_dir)
+             return self._exp_dir

          def get_exp_dir(run_id: int) -> Path:
              return self.run_dir / f"run_{run_id}"
@@ -81,9 +88,9 @@ class ArtifactsMixin(BaseTask[Config]):
          while (exp_dir := get_exp_dir(run_id)).is_dir() and has_lock_file(exp_dir):
              run_id += 1
          exp_dir.mkdir(exist_ok=True, parents=True)
-         exp_dir = exp_dir.expanduser().resolve()
-         logger.log(LOG_STATUS, exp_dir)
-         return exp_dir
+         self._exp_dir = exp_dir.expanduser().resolve()
+         logger.log(LOG_STATUS, self._exp_dir)
+         return self._exp_dir

      @functools.lru_cache(maxsize=None)
      def stage_environment(self) -> Path | None:
@@ -0,0 +1,209 @@
+ """Defines a mixin for handling model checkpointing."""
+
+ import io
+ import json
+ import logging
+ import tarfile
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+ from typing import Any, Callable, Generic, Literal, TypeVar, cast, overload
+
+ import cloudpickle
+ import optax
+ from jaxtyping import PyTree
+ from omegaconf import DictConfig, OmegaConf
+
+ from xax.core.conf import field
+ from xax.core.state import State
+ from xax.nn.parallel import is_master
+ from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+
+ logger = logging.getLogger(__name__)
+
+ CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
+
+
+ def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
+     """Defines the path to the checkpoint for a given state.
+
+     Args:
+         exp_dir: The experiment directory
+         state: The current trainer state
+
+     Returns:
+         The path to the checkpoint file.
+     """
+     if state is None:
+         return exp_dir / "checkpoints" / "ckpt.bin"
+     return exp_dir / "checkpoints" / f"ckpt.{state.num_steps}.bin"
+
+
+ @dataclass
+ class CheckpointingConfig(ArtifactsConfig):
+     save_every_n_steps: int | None = field(None, help="Save a checkpoint every N steps")
+     save_every_n_seconds: float | None = field(60.0 * 60.0, help="Save a checkpoint every N seconds")
+     only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
+     load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
+     load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
+
+
+ Config = TypeVar("Config", bound=CheckpointingConfig)
+
+
+ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
+     def __init__(self, config: Config) -> None:
+         super().__init__(config)
+
+         self.__last_ckpt_time = 0.0
+
+     def get_ckpt_path(self, state: State | None = None) -> Path:
+         return get_ckpt_path(self.exp_dir, state)
+
+     def get_init_ckpt_path(self) -> Path | None:
+         if self._exp_dir is not None:
+             ckpt_path = self.get_ckpt_path()
+             if ckpt_path.exists():
+                 return ckpt_path
+         if self.config.load_from_ckpt_path is not None:
+             ckpt_path = Path(self.config.load_from_ckpt_path)
+             assert ckpt_path.exists(), f"Checkpoint path {ckpt_path} does not exist."
+             return ckpt_path
+         return None
+
+     def should_checkpoint(self, state: State) -> bool:
+         if self.config.save_every_n_steps is not None:
+             if state.num_steps % self.config.save_every_n_steps == 0:
+                 return True
+         if self.config.save_every_n_seconds is not None:
+             last_time, cur_time = self.__last_ckpt_time, state.elapsed_time_s
+             if cur_time - last_time >= self.config.save_every_n_seconds:
+                 self.__last_ckpt_time = cur_time
+                 return True
+         return False
+
+     @overload
+     def load_checkpoint(
+         self,
+         path: Path,
+     ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["opt"]) -> optax.GradientTransformation: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["opt_state"]) -> optax.OptState: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["state"]) -> State: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["config"]) -> DictConfig: ...
+
+     def load_checkpoint(
+         self,
+         path: Path,
+         part: CheckpointPart | None = None,
+     ) -> (
+         tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+         | PyTree
+         | optax.GradientTransformation
+         | optax.OptState
+         | State
+         | DictConfig
+     ):
+         with tarfile.open(path, "r:gz") as tar:
+
+             def get_model() -> PyTree:
+                 if (model := tar.extractfile("model")) is None:
+                     raise ValueError(f"Checkpoint does not contain a model file: {path}")
+                 return cloudpickle.load(model)
+
+             def get_opt() -> optax.GradientTransformation:
+                 if (opt := tar.extractfile("opt")) is None:
+                     raise ValueError(f"Checkpoint does not contain an opt file: {path}")
+                 return cloudpickle.load(opt)
+
+             def get_opt_state() -> optax.OptState:
+                 if (opt_state := tar.extractfile("opt_state")) is None:
+                     raise ValueError(f"Checkpoint does not contain an opt_state file: {path}")
+                 return cloudpickle.load(opt_state)
+
+             def get_state() -> State:
+                 if (state := tar.extractfile("state")) is None:
+                     raise ValueError(f"Checkpoint does not contain a state file: {path}")
+                 return State(**json.loads(state.read().decode()))
+
+             def get_config() -> DictConfig:
+                 if (config := tar.extractfile("config")) is None:
+                     raise ValueError(f"Checkpoint does not contain a config file: {path}")
+                 return cast(DictConfig, OmegaConf.load(config))
+
+             match part:
+                 case "model":
+                     return get_model()
+                 case "opt":
+                     return get_opt()
+                 case "opt_state":
+                     return get_opt_state()
+                 case "state":
+                     return get_state()
+                 case "config":
+                     return get_config()
+                 case None:
+                     return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
+                 case _:
+                     raise ValueError(f"Invalid checkpoint part: {part}")
+
+     def save_checkpoint(
+         self,
+         model: PyTree,
+         optimizer: optax.GradientTransformation,
+         opt_state: optax.OptState,
+         state: State,
+     ) -> Path:
+         ckpt_path = self.get_ckpt_path(state)
+
+         if not is_master():
+             return ckpt_path
+
+         # Gets the path to the last checkpoint.
+         logger.info("Saving checkpoint to %s", ckpt_path)
+         last_ckpt_path = self.get_ckpt_path()
+         ckpt_path.parent.mkdir(exist_ok=True, parents=True)
+
+         # Potentially removes the last checkpoint.
+         if last_ckpt_path.exists() and self.config.only_save_most_recent:
+             if (base_ckpt := last_ckpt_path.resolve()).is_file():
+                 base_ckpt.unlink()
+
+         # Combines all temporary files into a single checkpoint TAR file.
+         with tarfile.open(ckpt_path, "w:gz") as tar:
+
+             def add_file(name: str, write_fn: Callable[[io.BytesIO], Any]) -> None:
+                 with io.BytesIO() as buf:
+                     write_fn(buf)
+                     tarinfo = tarfile.TarInfo(name)
+                     tarinfo.size = buf.tell()
+                     buf.seek(0)
+                     tar.addfile(tarinfo, buf)
+
+             add_file("model", lambda buf: cloudpickle.dump(model, buf))
+             add_file("opt", lambda buf: cloudpickle.dump(optimizer, buf))
+             add_file("opt_state", lambda buf: cloudpickle.dump(opt_state, buf))
+             add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
+             add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
+
+         # Updates the symlink to the new checkpoint.
+         last_ckpt_path.unlink(missing_ok=True)
+         try:
+             last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
+         except FileExistsError:
+             logger.exception("Exception while trying to update %s", ckpt_path)
+
+         # Marks directory as having artifacts which shouldn't be overwritten.
+         self.add_lock_file("ckpt", exists_ok=True)
+
+         return ckpt_path
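In short, the new CheckpointingMixin stores each checkpoint as a gzipped tar archive with five members (model, opt, opt_state, state, config) under exp_dir/checkpoints/, and keeps ckpt.bin as a symlink to the most recent archive. A hedged sketch of how a training loop might drive it (task, model, optimizer, opt_state, and state are placeholders, not defined in this diff):

    # Save whenever the step/interval policy in the config says so.
    if task.should_checkpoint(state):
        ckpt_path = task.save_checkpoint(model, optimizer, opt_state, state)

    # Restore all five parts at once...
    model, optimizer, opt_state, state, config = task.load_checkpoint(ckpt_path)
    # ...or just one part via the overloads.
    model_only = task.load_checkpoint(ckpt_path, "model")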
@@ -237,15 +237,15 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
          stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()

          if stats is not None:
-             self.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
-             self.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
-             self.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
-             self.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
-             self.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
-             self.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
-             self.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
-             self.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
-             self.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
-             self.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
+             self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
+             self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
+             self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
+             self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
+             self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
+             self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
+             self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
+             self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
+             self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
+             self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")

          return state
@@ -38,7 +38,6 @@ class DataloaderErrorConfig:

  @dataclass
  class DataloaderConfig:
-     batch_size: int = field(MISSING, help="Size of each batch")
      num_workers: int | None = field(MISSING, help="Number of workers for loading samples")
      prefetch_factor: int = field(2, help="Number of items to pre-fetch on each worker")
      error: DataloaderErrorConfig = field(DataloaderErrorConfig(), help="Dataloader error configuration")
@@ -49,11 +48,11 @@ class DataloadersConfig(ProcessConfig, BaseConfig):
      batch_size: int = field(MISSING, help="Size of each batch")
      raise_dataloader_errors: bool = field(False, help="If set, raise dataloader errors inside the worker processes")
      train_dl: DataloaderConfig = field(
-         DataloaderConfig(batch_size=II("batch_size")),
+         DataloaderConfig(num_workers=II("mlfab.num_workers:-1")),
          help="Train dataloader config",
      )
      valid_dl: DataloaderConfig = field(
-         DataloaderConfig(batch_size=II("batch_size"), num_workers=1),
+         DataloaderConfig(num_workers=1),
          help="Valid dataloader config",
      )
      debug_dataloader: bool = field(False, help="Debug dataloaders")
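With this change, batch_size lives only on the top-level DataloadersConfig; the per-dataloader batch_size field and its II("batch_size") interpolation are gone, and, as the later hunks show, the dataloaders read self.config.batch_size directly. A hedged sketch of the resulting config shape, with illustrative values only and other fields omitted:

    cfg = DataloadersConfig(
        batch_size=32,  # set once, shared by the train and valid dataloaders
        train_dl=DataloaderConfig(num_workers=4),
        valid_dl=DataloaderConfig(num_workers=1),
    )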
@@ -64,9 +63,7 @@ Config = TypeVar("Config", bound=DataloadersConfig)

  class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config], ABC):
      def __init__(self, config: Config) -> None:
-         if is_missing(config, "batch_size") and (
-             is_missing(config.train_dl, "batch_size") or is_missing(config.valid_dl, "batch_size")
-         ):
+         if is_missing(config, "batch_size"):
              config.batch_size = self.get_batch_size()

          super().__init__(config)
@@ -120,10 +117,10 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],

          return Dataloader(
              dataset=dataset,
-             batch_size=cfg.batch_size,
+             batch_size=self.config.batch_size,
              num_workers=0 if debugging else cfg.num_workers,
              prefetch_factor=cfg.prefetch_factor,
-             ctx=self.multiprocessing_context,
+             mp_manager=self.multiprocessing_manager,
              dataloader_worker_init_fn=self.dataloader_worker_init_fn,
              collate_worker_init_fn=self.collate_worker_init_fn,
              item_callback=self.dataloader_item_callback,
@@ -135,7 +132,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],

      @classmethod
      def to_device_fn(cls, sample: T) -> T:
-         return recursive_apply(sample, jax.device_put)
+         return recursive_apply(sample, jax.device_put, include_numpy=True)

      @classmethod
      def dataloader_worker_init_fn(cls, worker_id: int, num_workers: int) -> None:
@@ -250,8 +250,8 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
          for gpu_stat in stats.values():
              if gpu_stat is None:
                  continue
-             self.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
-             self.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
-             self.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
+             self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
+             self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
+             self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")

          return state