xax 0.0.7__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/task/logger.py CHANGED
@@ -223,10 +223,29 @@ class LogVideo:
     fps: int


+@dataclass(kw_only=True)
+class LogDistribution:
+    mean: Number
+    std: Number
+
+
+@dataclass(kw_only=True)
+class LogHistogram:
+    min: Number
+    max: Number
+    num: int
+    sum: Number
+    sum_squares: Number
+    bucket_limits: list[Number]
+    bucket_counts: list[int]
+
+
 @dataclass(kw_only=True)
 class LogLine:
     state: State
     scalars: dict[str, dict[str, Number]]
+    distributions: dict[str, dict[str, LogDistribution]]
+    histograms: dict[str, dict[str, LogHistogram]]
     strings: dict[str, dict[str, str]]
     images: dict[str, dict[str, LogImage]]
     videos: dict[str, dict[str, LogVideo]]
@@ -329,9 +348,9 @@ def image_with_text(
     else:
         text = text[:max_num_lines]
     width, height = image.size
-    font: ImageFont.ImageFont = ImageFont.load_default()
+    font: ImageFont.ImageFont | ImageFont.FreeTypeFont = ImageFont.load_default()
     _, _, _, line_height = font.getbbox(text[0])
-    new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)
+    new_width, new_height = width, int(height + line_spacing + max_num_lines * (line_height + line_spacing))
     padded_image = Image.new(image.mode, (new_width, new_height), 255)
     padded_image.paste(image, (0, 0))
     drawer = ImageDraw.Draw(padded_image)
@@ -497,6 +516,8 @@ class Logger:

     def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
         self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
+        self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
+        self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
         self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
         self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
         self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
@@ -522,6 +543,8 @@ class Logger:
         return LogLine(
             state=state,
             scalars={k: {kk: v() for kk, v in v.items()} for k, v in self.scalars.items()},
+            distributions={k: {kk: v() for kk, v in v.items()} for k, v in self.distributions.items()},
+            histograms={k: {kk: v() for kk, v in v.items()} for k, v in self.histograms.items()},
             strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
             images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
             videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
@@ -529,6 +552,8 @@ class Logger:

     def clear(self) -> None:
         self.scalars.clear()
+        self.distributions.clear()
+        self.histograms.clear()
         self.strings.clear()
         self.images.clear()
         self.videos.clear()
@@ -612,6 +637,76 @@ class Logger:

         self.scalars[namespace][key] = scalar_future

+    def log_distribution(
+        self,
+        key: str,
+        value: Callable[[], tuple[Number, Number]] | tuple[Number, Number],
+        *,
+        namespace: str | None = None,
+    ) -> None:
+        """Logs a distribution value.
+
+        Args:
+            key: The key being logged
+            value: The distribution value being logged, a tuple of (mean, std)
+            namespace: An optional logging namespace
+        """
+        if not self.active:
+            raise RuntimeError("The logger is not active")
+        namespace = self.resolve_namespace(namespace)
+
+        @functools.lru_cache(maxsize=None)
+        def distribution_future() -> LogDistribution:
+            mean, std = value() if callable(value) else value
+            return LogDistribution(mean=mean, std=std)
+
+        self.distributions[namespace][key] = distribution_future
+
+    def log_histogram(
+        self,
+        key: str,
+        value: Callable[[], np.ndarray | Array] | np.ndarray | Array,
+        *,
+        bins: int = 100,
+        namespace: str | None = None,
+    ) -> None:
+        """Logs a histogram value.
+
+        Args:
+            key: The key being logged
+            value: The histogram value being logged
+            bins: The number of bins to use for the histogram
+            namespace: An optional logging namespace
+        """
+        if not self.active:
+            raise RuntimeError("The logger is not active")
+        namespace = self.resolve_namespace(namespace)
+
+        @functools.lru_cache(maxsize=None)
+        def histogram_future() -> LogHistogram:
+            values = value() if callable(value) else value
+            values = values.reshape(-1)  # Must be flat.
+
+            if isinstance(values, Array):
+                counts, limits = jnp.histogram(values, bins=bins)
+                counts, limits = as_numpy(counts), as_numpy(limits)
+            elif isinstance(values, np.ndarray):
+                counts, limits = np.histogram(values, bins=bins)
+            else:
+                raise ValueError(f"Unsupported histogram type: {type(values)}")
+
+            return LogHistogram(
+                min=float(values.min()),
+                max=float(values.max()),
+                num=int(values.size),
+                sum=float(values.sum()),
+                sum_squares=float(values.dot(values)),
+                bucket_limits=limits[1:].tolist(),
+                bucket_counts=counts.tolist(),
+            )
+
+        self.histograms[namespace][key] = histogram_future
+
     def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
         """Logs a string value.

@@ -33,7 +33,7 @@ class StdoutLogger(LoggerImpl):
         self,
         write_fp: TextIO = sys.stdout,
         precision: int = 4,
-        log_timers: bool = False,
+        log_timers: bool = True,
         log_perf: bool = False,
         log_optim: bool = False,
         log_fp: bool = False,
@@ -98,7 +98,7 @@ class StdoutLogger(LoggerImpl):

         def add_logs(log: dict[str, dict[str, Any]], namespace_to_lines: dict[str, dict[str, str]]) -> None:
             for namespace, values in log.items():
-                if not self.log_timers and namespace.startswith(""):
+                if not self.log_timers and namespace.startswith(""):
                     continue
                 if not self.log_perf and namespace.startswith("🔧"):
                     continue
@@ -1,11 +1,9 @@
 """Defines a Tensorboard logger backend."""

 import atexit
-import functools
 import logging
 import os
 import re
-import shutil
 import subprocess
 import threading
 import time
@@ -140,15 +138,6 @@ class TensorboardLogger(LoggerImpl):
     def __del__(self) -> None:
         self.cleanup()

-    @functools.lru_cache(None)  # Avoid clearing logs multiple times.
-    def clear_logs(self) -> None:
-        if not self.log_directory.exists():
-            return
-        if not any(child.is_dir() for child in self.log_directory.iterdir()):
-            return
-        logger.warning("Clearing TensorBoard logs")
-        shutil.rmtree(self.log_directory)
-
     def get_writer(self, phase: Phase) -> TensorboardWriter:
         self._start()
         return self.writers.writer(phase)
@@ -162,9 +151,6 @@
         if not is_master():
             return

-        if line.state.num_steps == 0:
-            self.clear_logs()
-
         writer = self.get_writer(line.state.phase)
         walltime = line.state.start_time_s + line.state.elapsed_time_s

@@ -177,6 +163,31 @@
                 walltime=walltime,
             )

+        for namespace, distributions in line.distributions.items():
+            for distribution_key, distribution_value in distributions.items():
+                writer.add_gaussian_distribution(
+                    f"{namespace}/{distribution_key}",
+                    mean=float(distribution_value.mean),
+                    std=float(distribution_value.std),
+                    global_step=line.state.num_steps,
+                    walltime=walltime,
+                )
+
+        for namespace, histograms in line.histograms.items():
+            for histogram_key, histogram_value in histograms.items():
+                writer.add_histogram_raw(
+                    f"{namespace}/{histogram_key}",
+                    min=float(histogram_value.min),
+                    max=float(histogram_value.max),
+                    num=int(histogram_value.num),
+                    sum=float(histogram_value.sum),
+                    sum_squares=float(histogram_value.sum_squares),
+                    bucket_limits=[float(x) for x in histogram_value.bucket_limits],
+                    bucket_counts=[int(x) for x in histogram_value.bucket_counts],
+                    global_step=line.state.num_steps,
+                    walltime=walltime,
+                )
+
         for namespace, strings in line.strings.items():
             for string_key, string_value in strings.items():
                 writer.add_text(
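
A hedged sketch of how a custom backend could consume the new LogLine fields, mirroring the Tensorboard handling above. It assumes the LoggerImpl.write(line) hook that the built-in backends implement (other LoggerImpl methods are omitted) and only touches fields defined by the LogDistribution and LogHistogram dataclasses in this release.

from xax.task.logger import LoggerImpl, LogLine

class PrintingLogger(LoggerImpl):
    def write(self, line: LogLine) -> None:
        # Distributions carry a mean and standard deviation per key.
        for namespace, distributions in line.distributions.items():
            for key, dist in distributions.items():
                print(f"{namespace}/{key}: mean={dist.mean:.4f} std={dist.std:.4f}")
        # Histograms carry summary statistics plus bucket limits and counts.
        for namespace, histograms in line.histograms.items():
            for key, hist in histograms.items():
                print(f"{namespace}/{key}: n={hist.num} min={hist.min:.4f} max={hist.max:.4f}")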
@@ -3,7 +3,6 @@
 import functools
 import inspect
 import logging
-import os
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Self, TypeVar
@@ -54,20 +53,6 @@ class ArtifactsMixin(BaseTask[Config]):
         self._exp_dir = Path(exp_dir).expanduser().resolve()
         return self

-    def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
-        if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
-            if not exists_ok:
-                raise RuntimeError(f"Lock file already exists at {lock_file}")
-        else:
-            with open(lock_file, "w", encoding="utf-8") as f:
-                f.write(f"PID: {os.getpid()}")
-
-    def remove_lock_file(self, lock_type: str, *, missing_ok: bool = False) -> None:
-        if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
-            lock_file.unlink()
-        elif not missing_ok:
-            raise RuntimeError(f"Lock file not found at {lock_file}")
-
     def get_exp_dir(self) -> Path:
         if self._exp_dir is not None:
             return self._exp_dir
@@ -82,13 +67,8 @@ class ArtifactsMixin(BaseTask[Config]):
         def get_exp_dir(run_id: int) -> Path:
             return self.run_dir / f"run_{run_id}"

-        def has_lock_file(exp_dir: Path, lock_type: str | None = None) -> bool:
-            if lock_type is not None:
-                return (exp_dir / f".lock_{lock_type}").exists()
-            return any(exp_dir.glob(".lock_*"))
-
         run_id = 0
-        while (exp_dir := get_exp_dir(run_id)).is_dir() and has_lock_file(exp_dir):
+        while (exp_dir := get_exp_dir(run_id)).is_dir():
             run_id += 1
         exp_dir.mkdir(exist_ok=True, parents=True)
         self._exp_dir = exp_dir.expanduser().resolve()
@@ -21,7 +21,7 @@ from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin

 logger = logging.getLogger(__name__)

-CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
+CheckpointPart = Literal["model", "opt", "opt_state", "state", "config", "model_state_config", "all"]


 def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
@@ -88,8 +88,16 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def load_checkpoint(
         self,
         path: Path,
+        part: Literal["all"] = "all",
     ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...

+    @overload
+    def load_checkpoint(
+        self,
+        path: Path,
+        part: Literal["model_state_config"] = "model_state_config",
+    ) -> tuple[PyTree, State, DictConfig]: ...
+
     @overload
     def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...

@@ -108,15 +116,19 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def load_checkpoint(
         self,
         path: Path,
-        part: CheckpointPart | None = None,
+        part: CheckpointPart = "all",
     ) -> (
         tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+        | tuple[PyTree, State, DictConfig]
         | PyTree
         | optax.GradientTransformation
         | optax.OptState
         | State
         | DictConfig
     ):
+        # Calls the base callback.
+        self.on_before_checkpoint_load(path)
+
         with tarfile.open(path, "r:gz") as tar:

             def get_model() -> PyTree:
@@ -155,7 +167,9 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
                     return get_state()
                 case "config":
                     return get_config()
-                case None:
+                case "model_state_config":
+                    return get_model(), get_state(), get_config()
+                case "all":
                     return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
                 case _:
                     raise ValueError(f"Invalid checkpoint part: {part}")
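
The new "model_state_config" and "all" parts change what load_checkpoint returns. An illustrative call pattern (task stands for any CheckpointingMixin task and ckpt_path for an existing checkpoint archive; both names are placeholders):

model, optimizer, opt_state, state, config = task.load_checkpoint(ckpt_path)  # part="all" is now the default
model, state, config = task.load_checkpoint(ckpt_path, "model_state_config")  # skips the optimizer
model = task.load_checkpoint(ckpt_path, "model")  # single-part loads are unchanged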
@@ -215,7 +229,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
         except FileExistsError:
             logger.exception("Exception while trying to update %s", ckpt_path)

-        # Marks directory as having artifacts which shouldn't be overwritten.
-        self.add_lock_file("ckpt", exists_ok=True)
+        # Calls the base callback.
+        self.on_after_checkpoint_save(ckpt_path, state)

         return ckpt_path
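
Lock files are gone; checkpoint saving and loading now flow through overridable callbacks. A hedged sketch of hooking them in a task (the class name is invented and the signatures are inferred from the call sites above):

from pathlib import Path

from xax.core.state import State

class CheckpointLoggingMixin:
    # Mixed into a task alongside CheckpointingMixin; reports checkpoint movement.
    def on_before_checkpoint_load(self, path: Path) -> None:
        print(f"loading checkpoint from {path}")

    def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> None:
        print(f"saved checkpoint to {ckpt_path} at step {state.num_steps}")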
xax/task/mixins/logger.py CHANGED
@@ -8,6 +8,7 @@ from typing import Generic, Self, TypeVar

 import jax

+from xax.core.conf import field
 from xax.core.state import State
 from xax.task.base import BaseConfig, BaseTask
 from xax.task.logger import Logger, LoggerImpl
@@ -22,7 +23,14 @@ from xax.utils.text import is_interactive_session
 @jax.tree_util.register_dataclass
 @dataclass
 class LoggerConfig(BaseConfig):
-    pass
+    log_interval_seconds: float = field(
+        value=1.0,
+        help="The interval between successive log lines.",
+    )
+    tensorboard_log_interval_seconds: float = field(
+        value=10.0,
+        help="The interval between successive Tensorboard log lines.",
+    )


 Config = TypeVar("Config", bound=LoggerConfig)
@@ -49,11 +57,27 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
         self.logger.add_logger(*logger)

     def set_loggers(self) -> None:
-        self.add_logger(StdoutLogger() if is_interactive_session() else JsonLogger())
+        self.add_logger(
+            StdoutLogger(
+                log_interval_seconds=self.config.log_interval_seconds,
+            )
+            if is_interactive_session()
+            else JsonLogger(
+                log_interval_seconds=self.config.log_interval_seconds,
+            )
+        )
+
+        # If this is also an ArtifactsMixin, we should default add some
+        # additional loggers which log data to the artifacts directory.
         if isinstance(self, ArtifactsMixin):
             self.add_logger(
-                StateLogger(self.exp_dir),
-                TensorboardLogger(self.exp_dir),
+                StateLogger(
+                    run_directory=self.exp_dir,
+                ),
+                TensorboardLogger(
+                    run_directory=self.exp_dir,
+                    log_interval_seconds=self.config.tensorboard_log_interval_seconds,
+                ),
             )

     def write_logs(self, state: State) -> None:
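
For tasks that use LoggerMixin, the task config derives from LoggerConfig, so the logging cadence can now be tuned per task. A minimal sketch, assuming the config subclass shown here is the task's own Config (the name is illustrative):

from dataclasses import dataclass

from xax.core.conf import field
from xax.task.mixins.logger import LoggerConfig

@dataclass
class MyTaskConfig(LoggerConfig):
    # Write stdout/JSON log lines at most every 5 seconds and Tensorboard lines every 30.
    log_interval_seconds: float = field(value=5.0, help="Interval between log lines.")
    tensorboard_log_interval_seconds: float = field(value=30.0, help="Interval between Tensorboard log lines.")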
@@ -1,53 +1,39 @@
 """Defines a mixin to wrap some steps in a context manager."""

+import time
 from dataclasses import dataclass
 from types import TracebackType
-from typing import ContextManager, Literal, TypeVar
+from typing import Callable, ContextManager, TypeVar

-import equinox as eqx
 import jax

 from xax.task.base import BaseConfig, BaseTask

-StepType = Literal[
-    "backward",
-    "change_mode",
-    "clip_grads",
-    "create_optimizers",
-    "forward",
-    "get_dataloader",
-    "get_dataset",
-    "get_prefetcher",
-    "get_model",
-    "get_optimizer",
-    "get_initial_opt_state",
-    "get_update_fn",
-    "load_checkpoint",
-    "log_losses",
-    "model_to_device",
-    "on_step_end",
-    "on_step_start",
-    "save_checkpoint",
-    "step",
-    "update_state",
-    "write_logs",
-    "zero_grads",
-]
-

 class StepContext(ContextManager):
     """Context manager to get the current step type."""

-    CURRENT_STEP: StepType | None = None
+    CURRENT_STEP: str | None = None

-    def __init__(self, step: StepType) -> None:
+    def __init__(
+        self,
+        step: str,
+        on_context_start: Callable[[str], None],
+        on_context_end: Callable[[str, float], None],
+    ) -> None:
         self.step = step
+        self.start_time = 0.0
+        self.on_context_start = on_context_start
+        self.on_context_end = on_context_end

     def __enter__(self) -> None:
         StepContext.CURRENT_STEP = self.step
+        self.start_time = time.time()
+        self.on_context_start(self.step)

     def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
         StepContext.CURRENT_STEP = None
+        self.on_context_end(self.step, time.time() - self.start_time)


 @jax.tree_util.register_dataclass
@@ -63,6 +49,11 @@ class StepContextMixin(BaseTask[Config]):
     def __init__(self, config: Config) -> None:
         super().__init__(config)

-    @eqx.filter_jit
-    def step_context(self, step: StepType) -> ContextManager:
-        return StepContext(step)
+    def step_context(self, step: str) -> ContextManager:
+        return StepContext(step, self.on_context_start, self.on_context_stop)
+
+    def on_context_start(self, step: str) -> None:
+        pass
+
+    def on_context_stop(self, step: str, elapsed_time: float) -> None:
+        pass
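
Step contexts now report their wall-clock duration through these two no-op hooks. A hedged sketch of an override (the hook names and signatures come from the diff; the timing dict and class name are made up, and a real mixin would also forward to super().__init__):

import logging
from collections import defaultdict

logger = logging.getLogger(__name__)

class StepTimingMixin:
    # Mixed into a task alongside StepContextMixin to accumulate per-step timings.
    def __init__(self) -> None:
        self.step_times: dict[str, float] = defaultdict(float)

    def on_context_start(self, step: str) -> None:
        logger.debug("Entering step %s", step)

    def on_context_stop(self, step: str, elapsed_time: float) -> None:
        self.step_times[step] += elapsed_time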
xax/task/mixins/train.py CHANGED
@@ -24,6 +24,7 @@ from typing import (
     TypeVar,
     cast,
     get_args,
+    overload,
 )

 import equinox as eqx
@@ -35,6 +36,7 @@ from omegaconf import DictConfig

 from xax.core.conf import field
 from xax.core.state import Phase, State
+from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
 from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
@@ -115,7 +117,7 @@ class ValidStepTimer:
         if self.last_valid_time is None or self.last_valid_step is None:
             self.last_valid_time = state.elapsed_time_s
             self.last_valid_step = state.num_steps
-            return True
+            return False

         # Step-based validation.
         valid_every_n_steps = self.valid_every_n_steps
@@ -183,6 +185,9 @@ class TrainMixin(
     def __init__(self, config: Config) -> None:
         super().__init__(config)

+        # Sets the random seed whenever we instantiate a new train mixin.
+        set_random_seed(self.config.random_seed)
+
         # Timer for validation steps.
         self.valid_step_timer = ValidStepTimer(
             valid_every_n_steps=config.valid_every_n_steps,
@@ -279,31 +284,53 @@ class TrainMixin(
     def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))

+    @overload
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: Literal[False] = False,
+    ) -> tuple[PyTree, State]: ...
+
+    @overload
     def load_initial_state(
         self,
         key: PRNGKeyArray,
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
+        load_optimizer: Literal[True],
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]: ...
+
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: bool = False,
+    ) -> tuple[PyTree, State] | tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
         init_ckpt_path = self.get_init_ckpt_path()

         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-            with self.step_context("load_checkpoint"):
+            if load_optimizer:
                 model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
                 config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
                 if config_diff:
                     logger.warning("Loaded config differs from current config:\n%s", config_diff)
                 return model, optimizer, opt_state, state

-        with self.step_context("get_model"):
-            model = self.get_model(key)
+            else:
+                model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
+                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
+                if config_diff:
+                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+                return model, state
+
+        model = self.get_model(key)
+        state = State.init_state()

-        with self.step_context("get_optimizer"):
-            optimizer = self.get_optimizer()
+        if not load_optimizer:
+            return model, state

-        with self.step_context("get_initial_opt_state"):
-            opt_state = self.get_initial_opt_state(model, optimizer)
+        optimizer = self.get_optimizer()
+        opt_state = self.get_initial_opt_state(model, optimizer)

-        return model, optimizer, opt_state, State.init_state()
+        return model, optimizer, opt_state, state

     @eqx.filter_jit
     def get_output(self, model: PyTree, batch: Batch) -> Output:
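
The overloads above make the return shape depend on load_optimizer. A short illustrative call (task stands in for any TrainMixin task; it is a placeholder name):

import jax

key = jax.random.PRNGKey(0)
model, state = task.load_initial_state(key)  # default: model and state only
model, optimizer, opt_state, state = task.load_initial_state(key, load_optimizer=True)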
@@ -424,6 +451,7 @@
     def log_state(self) -> None:
         logger.log(LOG_STATUS, self.task_path)
         logger.log(LOG_STATUS, self.task_name)
+        logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
         self.logger.log_file("git_state.txt", get_git_state(self))
         self.logger.log_file("training_code.txt", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
@@ -456,7 +484,8 @@
             while not self.is_training_over(state):
                 if self.valid_step_timer.is_valid_step(state):
                     valid_batch = next(valid_pf)
-                    model, loss, output = self.val_step(model, valid_batch)
+                    with self.step_context("model_step"):
+                        model, loss, output = self.val_step(model, valid_batch)

                     # Perform logging.
                     with self.step_context("write_logs"):
@@ -464,22 +493,19 @@
                         self.log_step(model, valid_batch, output, loss, state)
                         state.num_valid_samples += 1

-                with self.step_context("on_step_start"):
-                    state = self.on_step_start(state)
+                state = self.on_step_start(state)

-                with self.step_context("update_state"):
+                with self.step_context("model_step"):
                     train_batch = next(train_pf)
                     model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)

-                # Perform logging.
                 with self.step_context("write_logs"):
                     state.phase = "train"
                     self.log_step(model, train_batch, output, loss, state)
                     state.num_steps += 1
                     state.num_samples += self.get_size_of_batch(train_batch) or 0

-                with self.step_context("on_step_end"):
-                    state = self.on_step_end(state)
+                state = self.on_step_end(state)

                 if self.should_checkpoint(state):
                     self.save_checkpoint(model, optimizer, opt_state, state)
@@ -496,14 +522,9 @@
             except NotImplementedError:
                 pass

-        with self.step_context("get_dataset"):
-            train_ds = self.get_dataset("train")
-
-        with self.step_context("get_dataloader"):
-            train_dl = self.get_dataloader(train_ds, "train")
-
-        with self.step_context("get_prefetcher"):
-            train_pf = self.get_prefetcher(train_dl)
+        train_ds = self.get_dataset("train")
+        train_dl = self.get_dataloader(train_ds, "train")
+        train_pf = self.get_prefetcher(train_dl)

         try:
             with train_pf as train_pf_ctx:
@@ -520,14 +541,9 @@
             except NotImplementedError:
                 pass

-        with self.step_context("get_dataset"):
-            valid_ds = self.get_dataset("valid")
-
-        with self.step_context("get_dataloader"):
-            valid_dl = self.get_dataloader(valid_ds, "valid")
-
-        with self.step_context("get_prefetcher"):
-            valid_pf = self.get_prefetcher(valid_dl)
+        valid_ds = self.get_dataset("valid")
+        valid_dl = self.get_dataloader(valid_ds, "valid")
+        valid_pf = self.get_prefetcher(valid_dl)

         try:
             with valid_pf as valid_pf_ctx:
@@ -559,7 +575,7 @@
             Thread(target=self.log_state, daemon=True).start()

             key, model_key = jax.random.split(key)
-            model, optimizer, opt_state, state = self.load_initial_state(model_key)
+            model, optimizer, opt_state, state = self.load_initial_state(model_key, load_optimizer=True)
             state = self.on_training_start(state)

             def on_exit() -> None: