xax 0.1.15__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
@@ -6,9 +6,9 @@ import logging
 import tarfile
 from dataclasses import asdict, dataclass
 from pathlib import Path
-from typing import Any, Callable, Generic, Literal, TypeVar, cast, overload
+from typing import Generic, Literal, TypeVar, cast, overload
 
-import cloudpickle
+import equinox as eqx
 import jax
 import optax
 from jaxtyping import PyTree
@@ -64,7 +64,9 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def get_init_ckpt_path(self) -> Path | None:
         if self._exp_dir is not None:
             ckpt_path = self.get_ckpt_path()
-            if ckpt_path.exists():
+            if not ckpt_path.exists():
+                logger.warning("No checkpoint found in experiment directory: %s", ckpt_path)
+            else:
                 return ckpt_path
         if self.config.load_from_ckpt_path is not None:
             ckpt_path = Path(self.config.load_from_ckpt_path)
@@ -87,41 +89,54 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def load_checkpoint(
         self,
         path: Path,
-        part: Literal["all"] = "all",
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+        *,
+        part: Literal["all"],
+        model_template: PyTree,
+        optimizer_template: PyTree,
+        opt_state_template: PyTree,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
 
     @overload
     def load_checkpoint(
         self,
         path: Path,
-        part: Literal["model_state_config"] = "model_state_config",
-    ) -> tuple[PyTree, State, DictConfig]: ...
+        *,
+        part: Literal["model_state_config"],
+        model_template: PyTree,
+    ) -> tuple[PyTree, State, Config]: ...
 
     @overload
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["model"],
+        model_template: PyTree,
     ) -> PyTree: ...
 
     @overload
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["opt"],
+        optimizer_template: PyTree,
     ) -> optax.GradientTransformation: ...
 
     @overload
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["opt_state"],
+        opt_state_template: PyTree,
     ) -> optax.OptState: ...
 
     @overload
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["state"],
     ) -> State: ...
 
@@ -129,48 +144,71 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["config"],
-    ) -> DictConfig: ...
+    ) -> Config: ...
 
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: CheckpointPart = "all",
+        model_template: PyTree | None = None,
+        optimizer_template: PyTree | None = None,
+        opt_state_template: PyTree | None = None,
     ) -> (
-        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
-        | tuple[PyTree, State, DictConfig]
+        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
+        | tuple[PyTree, State, Config]
         | PyTree
         | optax.GradientTransformation
         | optax.OptState
         | State
-        | DictConfig
+        | Config
     ):
+        """Load a checkpoint.
+
+        Args:
+            path: Path to the checkpoint directory
+            part: Which part of the checkpoint to load
+            model_template: Template model with correct structure but uninitialized weights
+            optimizer_template: Template optimizer with correct structure but uninitialized weights
+            opt_state_template: Template optimizer state with correct structure but uninitialized weights
+
+        Returns:
+            The requested checkpoint components
+        """
         with tarfile.open(path, "r:gz") as tar:
 
             def get_model() -> PyTree:
+                if model_template is None:
+                    raise ValueError("model_template must be provided to load model weights")
                 if (model := tar.extractfile("model")) is None:
                     raise ValueError(f"Checkpoint does not contain a model file: {path}")
-                return cloudpickle.load(model)
+                return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
 
             def get_opt() -> optax.GradientTransformation:
-                if (opt := tar.extractfile("opt")) is None:
-                    raise ValueError(f"Checkpoint does not contain an opt file: {path}")
-                return cloudpickle.load(opt)
+                if optimizer_template is None:
+                    raise ValueError("optimizer_template must be provided to load optimizer")
+                if (opt := tar.extractfile("optimizer")) is None:
+                    raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
+                return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
 
             def get_opt_state() -> optax.OptState:
+                if opt_state_template is None:
+                    raise ValueError("opt_state_template must be provided to load optimizer state")
                 if (opt_state := tar.extractfile("opt_state")) is None:
-                    raise ValueError(f"Checkpoint does not contain an opt_state file: {path}")
-                return cloudpickle.load(opt_state)
+                    raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
+                return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
 
             def get_state() -> State:
                 if (state := tar.extractfile("state")) is None:
                     raise ValueError(f"Checkpoint does not contain a state file: {path}")
                 return State(**json.loads(state.read().decode()))
 
-            def get_config() -> DictConfig:
+            def get_config() -> Config:
                 if (config := tar.extractfile("config")) is None:
                     raise ValueError(f"Checkpoint does not contain a config file: {path}")
-                return cast(DictConfig, OmegaConf.load(config))
+                return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)
 
             match part:
                 case "model":
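
`load_checkpoint` is now keyword-only and needs a template PyTree for each serialized component, since `eqx.tree_deserialise_leaves` requires a structure to deserialize into. As a rough, non-authoritative sketch of how a caller might use the new API (mirroring the pattern the train mixin adopts later in this diff; `task`, `ckpt_path`, and `restore_model` are placeholder names, not part of the package):

```python
from pathlib import Path

import equinox as eqx
from jaxtyping import PRNGKeyArray


def restore_model(task, ckpt_path: Path, key: PRNGKeyArray):
    """Sketch: 0.2.0 deserializes into a shape-only template instead of unpickling."""
    # filter_eval_shape traces get_model without allocating real weights, producing
    # a PyTree with the right structure for eqx.tree_deserialise_leaves.
    model_template = eqx.filter_eval_shape(task.get_model, key)
    return task.load_checkpoint(
        ckpt_path,
        part="model_state_config",
        model_template=model_template,
    )
```

The same pattern applies to `part="opt"` and `part="opt_state"`, each with its own `*_template` argument.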
@@ -192,51 +230,90 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
 
     def save_checkpoint(
         self,
-        model: PyTree,
-        optimizer: optax.GradientTransformation,
-        opt_state: optax.OptState,
-        state: State,
+        model: PyTree | None = None,
+        optimizer: optax.GradientTransformation | None = None,
+        opt_state: optax.OptState | None = None,
+        aux_data: PyTree | None = None,
+        state: State | None = None,
     ) -> Path:
+        """Save a checkpoint.
+
+        Args:
+            model: The model to save
+            state: The current training state
+            optimizer: The optimizer to save
+            aux_data: Additional data to save
+            opt_state: The optimizer state to save
+
+        Returns:
+            Path to the saved checkpoint
+        """
         ckpt_path = self.get_ckpt_path(state)
 
         if not is_master():
             return ckpt_path
 
-        # Gets the path to the last checkpoint.
+        # Gets the path to the last checkpoint
        logger.info("Saving checkpoint to %s", ckpt_path)
         last_ckpt_path = self.get_ckpt_path()
         ckpt_path.parent.mkdir(exist_ok=True, parents=True)
 
-        # Potentially removes the last checkpoint.
+        # Potentially removes the last checkpoint
         if last_ckpt_path.exists() and self.config.only_save_most_recent:
             if (base_ckpt := last_ckpt_path.resolve()).is_file():
                 base_ckpt.unlink()
 
-        # Combines all temporary files into a single checkpoint TAR file.
+        # Save the checkpoint components
         with tarfile.open(ckpt_path, "w:gz") as tar:
 
-            def add_file(name: str, write_fn: Callable[[io.BytesIO], Any]) -> None:
+            def add_file(name: str, buf: io.BytesIO) -> None:
+                tarinfo = tarfile.TarInfo(name)
+                tarinfo.size = buf.tell()
+                buf.seek(0)
+                tar.addfile(tarinfo, buf)
+
+            # Save model using Equinox
+            if model is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, model)
+                    add_file("model", buf)
+
+            # Save optimizer using Equinox
+            if optimizer is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, optimizer)
+                    add_file("optimizer", buf)
+
+            # Save optimizer state using Equinox
+            if opt_state is not None:
                 with io.BytesIO() as buf:
-                    write_fn(buf)
-                    tarinfo = tarfile.TarInfo(name)
-                    tarinfo.size = buf.tell()
-                    buf.seek(0)
-                    tar.addfile(tarinfo, buf)
-
-            add_file("model", lambda buf: cloudpickle.dump(model, buf))
-            add_file("opt", lambda buf: cloudpickle.dump(optimizer, buf))
-            add_file("opt_state", lambda buf: cloudpickle.dump(opt_state, buf))
-            add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
-            add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
-
-        # Updates the symlink to the new checkpoint.
+                    eqx.tree_serialise_leaves(buf, opt_state)
+                    add_file("opt_state", buf)
+
+            # Save aux data using Equinox.
+            if aux_data is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, aux_data)
+                    add_file("aux_data", buf)
+
+            # Save state and config as JSON
+            def add_file_bytes(name: str, data: bytes) -> None:  # noqa: ANN401
+                info = tarfile.TarInfo(name=name)
+                info.size = len(data)
+                tar.addfile(info, io.BytesIO(data))
+
+            if state is not None:
+                add_file_bytes("state", json.dumps(asdict(state), indent=2).encode())
+            add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
+
+        # Updates the symlink to the new checkpoint
         last_ckpt_path.unlink(missing_ok=True)
         try:
             last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
         except FileExistsError:
             logger.exception("Exception while trying to update %s", ckpt_path)
 
-        # Calls the base callback.
+        # Calls the base callback
         self.on_after_checkpoint_save(ckpt_path, state)
 
         return ckpt_path
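
Since every component of `save_checkpoint` is now optional and keyword-based, callers can write partial checkpoints. A minimal sketch, assuming `task` is any object that mixes in `CheckpointingMixin` (the helper names here are illustrative):

```python
# Sketch: save only the model weights and training state; omitted components
# (optimizer, opt_state, aux_data) are simply left out of the tar archive.
def save_weights_only(task, model, state):
    return task.save_checkpoint(model=model, state=state)


# The full call used by the train mixin later in this diff, for comparison.
def save_full(task, model, optimizer, opt_state, state):
    return task.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
```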
@@ -248,15 +248,15 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
         stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
 
         if stats is not None:
-            self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
-            self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
-            self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
-            self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
-            self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
-            self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
-            self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
-            self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
-            self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
-            self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
+            self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu", secondary=True)
+            self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu", secondary=True)
+            self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu", secondary=True)
+            self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem", secondary=True)
+            self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem", secondary=True)
+            self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem", secondary=True)
+            self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem", secondary=True)
+            self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem", secondary=True)
+            self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem", secondary=True)
+            self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem", secondary=True)
 
         return state
@@ -9,6 +9,7 @@ import jax
 from dpshdl.dataloader import CollatedDataloaderItem, Dataloader
 from dpshdl.dataset import Dataset, ErrorHandlingDataset
 from dpshdl.prefetcher import Prefetcher
+from jaxtyping import PRNGKeyArray
 from omegaconf import II, MISSING
 
 from xax.core.conf import field, is_missing
@@ -103,7 +104,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
             "or `get_data_iterator` to return an iterator for the given dataset."
         )
 
-    def get_data_iterator(self, phase: Phase) -> Iterator:
+    def get_data_iterator(self, phase: Phase, key: PRNGKeyArray) -> Iterator:
         raise NotImplementedError(
             "You must implement either the `get_dataset` method to return the dataset for the given phase, "
             "or `get_data_iterator` to return an iterator for the given dataset."
@@ -264,8 +264,8 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
         for gpu_stat in stats.values():
             if gpu_stat is None:
                 continue
-            self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
-            self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
-            self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
+            self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu", secondary=True)
+            self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu", secondary=True)
+            self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu", secondary=True)
 
         return state
xax/task/mixins/train.py CHANGED
@@ -11,7 +11,7 @@ import textwrap
 import time
 import traceback
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, is_dataclass
+from dataclasses import asdict, dataclass, is_dataclass
 from threading import Thread
 from typing import (
     Any,
@@ -33,7 +33,6 @@ import jax.numpy as jnp
 import numpy as np
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
-from omegaconf import DictConfig
 
 from xax.core.conf import field
 from xax.core.state import Phase, State
@@ -50,6 +49,7 @@ from xax.utils.experiments import (
     TrainingFinishedError,
     diff_configs,
     get_diff_string,
+    get_info_json,
     get_state_file_string,
     get_training_code,
 )
@@ -218,7 +218,12 @@ class TrainMixin(
         return state.replace(elapsed_time_s=time.time() - state.start_time_s)
 
     def log_train_step(
-        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+        self,
+        model: PyTree,
+        batch: Batch,
+        output: Output,
+        metrics: FrozenDict[str, Array],
+        state: State,
     ) -> None:
         """Override this function to do logging during the training phase.
 
@@ -234,7 +239,12 @@ class TrainMixin(
         """
 
     def log_valid_step(
-        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+        self,
+        model: PyTree,
+        batch: Batch,
+        output: Output,
+        metrics: FrozenDict[str, Array],
+        state: State,
     ) -> None:
         """Override this function to do logging during the validation phase.
 
@@ -252,12 +262,20 @@ class TrainMixin(
     def log_state_timers(self, state: State) -> None:
         timer = self.state_timers[state.phase]
         timer.step(state)
-        for ns, d in timer.log_dict().items():
-            for k, v in d.items():
-                self.logger.log_scalar(k, v, namespace=ns)
+        for k, v in timer.log_dict().items():
+            if isinstance(v, tuple):
+                v, secondary = v
+            else:
+                secondary = False
+            self.logger.log_scalar(k, v, namespace="⌛ timers", secondary=secondary)
 
     def log_step(
-        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+        self,
+        model: PyTree,
+        batch: Batch,
+        output: Output,
+        metrics: FrozenDict[str, Array],
+        state: State,
     ) -> None:
         phase = state.phase
 
@@ -322,20 +340,30 @@ class TrainMixin(
 
         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-            if load_optimizer:
-                model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
-                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
-                if config_diff:
-                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
-                return model, optimizer, opt_state, state
+            model_spec = eqx.filter_eval_shape(self.get_model, key)
+            model, state, config = self.load_checkpoint(
+                init_ckpt_path,
+                part="model_state_config",
+                model_template=model_spec,
+            )
+            config_diff = get_diff_string(diff_configs(asdict(config), asdict(self.config)))
+            if config_diff:
+                logger.warning("Loaded config differs from current config:\n%s", config_diff)
 
-            else:
-                model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
-                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
-                if config_diff:
-                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+            if not load_optimizer:
                 return model, state
 
+            # Loads the optimizer.
+            optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
+            optimizer = self.load_checkpoint(init_ckpt_path, part="opt", optimizer_template=optimizer_spec)
+
+            # Loads the optimizer state.
+            opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+            opt_state = self.load_checkpoint(init_ckpt_path, part="opt_state", opt_state_template=opt_state_spec)
+
+            return model, optimizer, opt_state, state
+
+        logger.info("No checkpoint found. Initializing a new model.")
         model = self.get_model(key)
         state = State.init_state()
 
@@ -536,6 +564,7 @@ class TrainMixin(
         self.logger.log_file("state.txt", get_state_file_string(self))
         self.logger.log_file("training_code.py", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
+        self.logger.log_file("info.json", get_info_json())
 
     def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
         return eqx.is_inexact_array(item)
@@ -609,16 +638,16 @@ class TrainMixin(
 
             if self.should_checkpoint(state):
                 model = eqx.combine(model_arr, model_static)
-                self.save_checkpoint(model, optimizer, opt_state, state)
+                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
         # After finishing training, save the final checkpoint.
         model = eqx.combine(model_arr, model_static)
-        self.save_checkpoint(model, optimizer, opt_state, state)
+        self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
     @contextlib.contextmanager
-    def get_train_iterator(self) -> Generator[Iterator[Batch], None, None]:
+    def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
         try:
-            train_iterator: Iterator[Batch] = self.get_data_iterator("train")
+            train_iterator: Iterator[Batch] = self.get_data_iterator("train", key=key)
             yield train_iterator
             return
         except NotImplementedError:
@@ -635,9 +664,9 @@ class TrainMixin(
             logger.info("Closing train prefetcher")
 
     @contextlib.contextmanager
-    def get_valid_iterator(self) -> Generator[Iterator[Batch], None, None]:
+    def get_valid_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
         try:
-            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid")
+            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid", key=key)
             yield valid_iterator
             return
         except NotImplementedError:
@@ -681,12 +710,13 @@ class TrainMixin(
             state = self.on_training_start(state)
 
            def on_exit() -> None:
-                self.save_checkpoint(model, optimizer, opt_state, state)
+                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
             # Handle user-defined interrupts during the training loop.
             self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
 
-            with self.get_train_iterator() as train_pf, self.get_valid_iterator() as valid_pf:
+            key, tkey, vkey = jax.random.split(key, 3)
+            with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
                 try:
                     self.train_loop(
                         model=model,
@@ -703,7 +733,7 @@ class TrainMixin(
                         f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
                         important=True,
                     )
-                    self.save_checkpoint(model, optimizer, opt_state, state)
+                    self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
                 except (KeyboardInterrupt, bdb.BdbQuit):
                     if is_master():
@@ -713,7 +743,7 @@ class TrainMixin(
                     exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
                     sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
                     sys.stdout.flush()
-                    self.save_checkpoint(model, optimizer, opt_state, state)
+                    self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
                 finally:
                     state = self.on_training_end(state)
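
`run` now splits its PRNG key and threads separate keys through `get_train_iterator` and `get_valid_iterator` into `get_data_iterator`, so tasks that override the iterator hook must accept the key. A hypothetical override with synthetic data, purely for illustration:

```python
from typing import Iterator

import jax
from jaxtyping import Array, PRNGKeyArray


def get_data_iterator(phase: str, key: PRNGKeyArray) -> Iterator[Array]:
    """Sketch of the new signature: each batch is drawn from a fresh key split."""
    batch_size = 32 if phase == "train" else 8
    while True:
        key, subkey = jax.random.split(key)
        yield jax.random.normal(subkey, (batch_size, 10))
```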
xax/utils/experiments.py CHANGED
@@ -7,6 +7,7 @@ import functools
 import hashlib
 import inspect
 import itertools
+import json
 import logging
 import math
 import os
@@ -24,7 +25,7 @@ import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
 from types import TracebackType
-from typing import Any, Iterator, Self, TypeVar, cast
+from typing import Any, Iterator, Mapping, Self, Sequence, TypeVar, cast
 from urllib.parse import urlparse
 
 import git
@@ -114,28 +115,13 @@ class StateTimer:
         self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
         self.iter_timer.step(cur_time)
 
-    def log_dict(self) -> dict[str, dict[str, int | float]]:
-        logs: dict[str, dict[str, int | float]] = {}
-
-        # Logs step statistics.
-        logs["⌛ steps"] = {
-            "total": self.step_timer.steps,
-            "per-second": self.step_timer.steps_per_second,
-        }
-
-        # Logs sample statistics.
-        logs["⌛ samples"] = {
-            "total": self.sample_timer.steps,
-            "per-second": self.sample_timer.steps_per_second,
+    def log_dict(self) -> dict[str, int | float | tuple[int | float, bool]]:
+        return {
+            "steps/second": self.step_timer.steps_per_second,
+            "samples/second": (self.sample_timer.steps_per_second, True),
+            "dt": self.iter_timer.iter_seconds,
         }
 
-        # Logs full iteration statistics.
-        logs["⌛ dt"] = {
-            "iter": self.iter_timer.iter_seconds,
-        }
-
-        return logs
-
 
 
 class IntervalTicker:
     def __init__(self, interval: float) -> None:
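
The new `log_dict` is flat: each key maps either to a bare value or to a `(value, secondary)` tuple, which `TrainMixin.log_state_timers` (earlier in this diff) unpacks before calling `log_scalar`. A small sketch of that consumption, with made-up numbers:

```python
# Sketch: unpacking the flat StateTimer.log_dict format, where a tuple's second
# element marks the scalar as "secondary" for the logger.
log_dict: dict[str, float | tuple[float, bool]] = {
    "steps/second": 12.5,
    "samples/second": (400.0, True),
    "dt": 0.08,
}

for name, value in log_dict.items():
    scalar, secondary = value if isinstance(value, tuple) else (value, False)
    print(f"{name}: {scalar} (secondary={secondary})")
```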
@@ -217,8 +203,8 @@ class MinGradScaleError(TrainingFinishedError):
 
 
 def diff_configs(
-    first: ListConfig | DictConfig,
-    second: ListConfig | DictConfig,
+    first: Mapping | Sequence,
+    second: Mapping | Sequence,
     prefix: str | None = None,
 ) -> tuple[list[str], list[str]]:
     """Returns the difference between two configs.
@@ -245,7 +231,7 @@ def diff_configs(
 
     any_config = (ListConfig, DictConfig)
 
-    if isinstance(first, DictConfig) and isinstance(second, DictConfig):
+    if isinstance(first, Mapping) and isinstance(second, Mapping):
         first_keys, second_keys = cast(set[str], set(first.keys())), cast(set[str], set(second.keys()))
 
         # Gets the new keys in each config.
@@ -255,11 +241,12 @@ def diff_configs(
         # Gets the new sub-keys in each config.
         for key in first_keys.intersection(second_keys):
             sub_prefix = key if prefix is None else f"{prefix}.{key}"
-            if OmegaConf.is_missing(first, key) or OmegaConf.is_missing(second, key):
-                if not OmegaConf.is_missing(first, key):
-                    new_first += [get_diff_string(sub_prefix, first[key])]
-                if not OmegaConf.is_missing(second, key):
-                    new_second += [get_diff_string(sub_prefix, second[key])]
+            if isinstance(first, DictConfig) and isinstance(second, DictConfig):
+                if OmegaConf.is_missing(first, key) or OmegaConf.is_missing(second, key):
+                    if not OmegaConf.is_missing(first, key):
+                        new_first += [get_diff_string(sub_prefix, first[key])]
+                    if not OmegaConf.is_missing(second, key):
+                        new_second += [get_diff_string(sub_prefix, second[key])]
             elif isinstance(first[key], any_config) and isinstance(second[key], any_config):
                 sub_new_first, sub_new_second = diff_configs(first[key], second[key], prefix=sub_prefix)
                 new_first, new_second = new_first + sub_new_first, new_second + sub_new_second
@@ -268,7 +255,7 @@ def diff_configs(
                 new_first += [get_diff_string(sub_prefix, first_val)]
                 new_second += [get_diff_string(sub_prefix, second_val)]
 
-    elif isinstance(first, ListConfig) and isinstance(second, ListConfig):
+    elif isinstance(first, Sequence) and isinstance(second, Sequence):
         if len(first) > len(second):
             for i in range(len(second), len(first)):
                 new_first += [get_diff_string(prefix, first[i])]
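
With the signature loosened from OmegaConf containers to `Mapping | Sequence`, `diff_configs` can also compare plain dictionaries, which is how the train mixin now calls it after converting dataclass configs with `asdict`. An illustrative sketch (the `Config` dataclass is made up for this example):

```python
from dataclasses import asdict, dataclass

from xax.utils.experiments import diff_configs, get_diff_string


@dataclass
class Config:
    learning_rate: float = 1e-3
    batch_size: int = 32


loaded, current = Config(learning_rate=1e-3), Config(learning_rate=3e-4)

# diff_configs returns two lists of human-readable diff strings, one per side.
print(get_diff_string(diff_configs(asdict(loaded), asdict(current))))
```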
@@ -483,16 +470,33 @@ def get_command_line_string() -> str:
     return " ".join(sys.argv)
 
 
+def get_environment_variables() -> str:
+    return "\n".join([f"{key}={value}" for key, value in sorted(os.environ.items())])
+
+
 def get_state_file_string(obj: object) -> str:
     return "\n\n".join(
         [
             f"=== Command Line ===\n\n{get_command_line_string()}",
             f"=== Git State ===\n\n{get_git_state(obj)}",
             f"=== Packages ===\n\n{get_packages_with_versions()}",
+            f"=== Environment Variables ===\n\n{get_environment_variables()}",
         ]
     )
 
 
+def get_info_json() -> str:
+    return json.dumps(
+        {
+            "process_id": os.getpid(),
+            "job": {
+                "start_time": datetime.datetime.now().isoformat(),
+            },
+        },
+        indent=2,
+    )
+
+
 def get_training_code(obj: object) -> str:
     """Gets the text from the file containing the provided object.
 