xax 0.0.7__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.
xax/task/mixins/step_wrapper.py CHANGED
@@ -1,53 +1,39 @@
 """Defines a mixin to wrap some steps in a context manager."""
 
+import time
 from dataclasses import dataclass
 from types import TracebackType
-from typing import ContextManager, Literal, TypeVar
+from typing import Callable, ContextManager, TypeVar
 
-import equinox as eqx
 import jax
 
 from xax.task.base import BaseConfig, BaseTask
 
-StepType = Literal[
-    "backward",
-    "change_mode",
-    "clip_grads",
-    "create_optimizers",
-    "forward",
-    "get_dataloader",
-    "get_dataset",
-    "get_prefetcher",
-    "get_model",
-    "get_optimizer",
-    "get_initial_opt_state",
-    "get_update_fn",
-    "load_checkpoint",
-    "log_losses",
-    "model_to_device",
-    "on_step_end",
-    "on_step_start",
-    "save_checkpoint",
-    "step",
-    "update_state",
-    "write_logs",
-    "zero_grads",
-]
-
 
 class StepContext(ContextManager):
     """Context manager to get the current step type."""
 
-    CURRENT_STEP: StepType | None = None
+    CURRENT_STEP: str | None = None
 
-    def __init__(self, step: StepType) -> None:
+    def __init__(
+        self,
+        step: str,
+        on_context_start: Callable[[str], None],
+        on_context_end: Callable[[str, float], None],
+    ) -> None:
         self.step = step
+        self.start_time = 0.0
+        self.on_context_start = on_context_start
+        self.on_context_end = on_context_end
 
     def __enter__(self) -> None:
        StepContext.CURRENT_STEP = self.step
+        self.start_time = time.time()
+        self.on_context_start(self.step)
 
     def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
         StepContext.CURRENT_STEP = None
+        self.on_context_end(self.step, time.time() - self.start_time)
 
 
 @jax.tree_util.register_dataclass
@@ -63,6 +49,11 @@ class StepContextMixin(BaseTask[Config]):
     def __init__(self, config: Config) -> None:
         super().__init__(config)
 
-    @eqx.filter_jit
-    def step_context(self, step: StepType) -> ContextManager:
-        return StepContext(step)
+    def step_context(self, step: str) -> ContextManager:
+        return StepContext(step, self.on_context_start, self.on_context_stop)
+
+    def on_context_start(self, step: str) -> None:
+        pass
+
+    def on_context_stop(self, step: str, elapsed_time: float) -> None:
+        pass
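
A minimal sketch of how the new timing hooks behave, assuming the StepContext class above is in scope. The callback names and printed output here are illustrative, not part of the xax API; subclasses of StepContextMixin get the same behavior by overriding on_context_start / on_context_stop:

    import time

    def on_start(step: str) -> None:
        # Fires when the context is entered.
        print(f"starting {step}")

    def on_end(step: str, elapsed: float) -> None:
        # Fires on exit with the wall-clock duration of the block.
        print(f"{step} took {elapsed:.4f}s")

    # Any named block can now be timed by string name.
    with StepContext("model_step", on_start, on_end):
        time.sleep(0.1)  # stand-in for a forward/backward pass
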
xax/task/mixins/train.py CHANGED
@@ -24,6 +24,7 @@ from typing import (
     TypeVar,
     cast,
     get_args,
+    overload,
 )
 
 import equinox as eqx
@@ -35,6 +36,7 @@ from omegaconf import DictConfig
 
 from xax.core.conf import field
 from xax.core.state import Phase, State
+from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
 from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
@@ -115,7 +117,7 @@ class ValidStepTimer:
         if self.last_valid_time is None or self.last_valid_step is None:
             self.last_valid_time = state.elapsed_time_s
             self.last_valid_step = state.num_steps
-            return True
+            return False
 
         # Step-based validation.
         valid_every_n_steps = self.valid_every_n_steps
@@ -154,7 +156,6 @@ class TrainConfig(
     valid_first_n_steps: int = field(0, help="Treat the first N steps as validation steps")
     valid_every_n_seconds: float | None = field(60.0 * 10.0, help="Run validation every N seconds")
     valid_first_n_seconds: float | None = field(60.0, help="Run first validation after N seconds")
-    batch_dim: int = field(0, help="The batch dimension, for splitting batches into chunks")
     max_steps: int | None = field(None, help="Maximum number of steps to run")
     step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
     random_seed: int = field(1337, help="Random seed for the task")
@@ -183,6 +184,9 @@ class TrainMixin(
     def __init__(self, config: Config) -> None:
         super().__init__(config)
 
+        # Sets the random seed whenever we instantiate a new train mixin.
+        set_random_seed(self.config.random_seed)
+
         # Timer for validation steps.
         self.valid_step_timer = ValidStepTimer(
             valid_every_n_steps=config.valid_every_n_steps,
@@ -279,31 +283,53 @@ class TrainMixin(
     def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))
 
+    @overload
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: Literal[False] = False,
+    ) -> tuple[PyTree, State]: ...
+
+    @overload
     def load_initial_state(
         self,
         key: PRNGKeyArray,
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
+        load_optimizer: Literal[True],
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]: ...
+
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: bool = False,
+    ) -> tuple[PyTree, State] | tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
         init_ckpt_path = self.get_init_ckpt_path()
 
         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-            with self.step_context("load_checkpoint"):
+            if load_optimizer:
                 model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
                 config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
                 if config_diff:
                     logger.warning("Loaded config differs from current config:\n%s", config_diff)
                 return model, optimizer, opt_state, state
 
-        with self.step_context("get_model"):
-            model = self.get_model(key)
+            else:
+                model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
+                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
+                if config_diff:
+                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+                return model, state
+
+        model = self.get_model(key)
+        state = State.init_state()
 
-        with self.step_context("get_optimizer"):
-            optimizer = self.get_optimizer()
+        if not load_optimizer:
+            return model, state
 
-        with self.step_context("get_initial_opt_state"):
-            opt_state = self.get_initial_opt_state(model, optimizer)
+        optimizer = self.get_optimizer()
+        opt_state = self.get_initial_opt_state(model, optimizer)
 
-        return model, optimizer, opt_state, State.init_state()
+        return model, optimizer, opt_state, state
 
     @eqx.filter_jit
     def get_output(self, model: PyTree, batch: Batch) -> Output:
@@ -424,6 +450,7 @@ class TrainMixin(
     def log_state(self) -> None:
         logger.log(LOG_STATUS, self.task_path)
         logger.log(LOG_STATUS, self.task_name)
+        logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
         self.logger.log_file("git_state.txt", get_git_state(self))
         self.logger.log_file("training_code.txt", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
@@ -456,7 +483,8 @@ class TrainMixin(
         while not self.is_training_over(state):
             if self.valid_step_timer.is_valid_step(state):
                 valid_batch = next(valid_pf)
-                model, loss, output = self.val_step(model, valid_batch)
+                with self.step_context("model_step"):
+                    model, loss, output = self.val_step(model, valid_batch)
 
                 # Perform logging.
                 with self.step_context("write_logs"):
@@ -464,22 +492,19 @@ class TrainMixin(
                    self.log_step(model, valid_batch, output, loss, state)
                    state.num_valid_samples += 1
 
-            with self.step_context("on_step_start"):
-                state = self.on_step_start(state)
+            state = self.on_step_start(state)
 
-            with self.step_context("update_state"):
+            with self.step_context("model_step"):
                 train_batch = next(train_pf)
                 model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)
 
-            # Perform logging.
             with self.step_context("write_logs"):
                 state.phase = "train"
                 self.log_step(model, train_batch, output, loss, state)
                 state.num_steps += 1
                 state.num_samples += self.get_size_of_batch(train_batch) or 0
 
-            with self.step_context("on_step_end"):
-                state = self.on_step_end(state)
+            state = self.on_step_end(state)
 
             if self.should_checkpoint(state):
                 self.save_checkpoint(model, optimizer, opt_state, state)
@@ -496,14 +521,9 @@ class TrainMixin(
         except NotImplementedError:
             pass
 
-        with self.step_context("get_dataset"):
-            train_ds = self.get_dataset("train")
-
-        with self.step_context("get_dataloader"):
-            train_dl = self.get_dataloader(train_ds, "train")
-
-        with self.step_context("get_prefetcher"):
-            train_pf = self.get_prefetcher(train_dl)
+        train_ds = self.get_dataset("train")
+        train_dl = self.get_dataloader(train_ds, "train")
+        train_pf = self.get_prefetcher(train_dl)
 
         try:
             with train_pf as train_pf_ctx:
@@ -520,14 +540,9 @@ class TrainMixin(
         except NotImplementedError:
             pass
 
-        with self.step_context("get_dataset"):
-            valid_ds = self.get_dataset("valid")
-
-        with self.step_context("get_dataloader"):
-            valid_dl = self.get_dataloader(valid_ds, "valid")
-
-        with self.step_context("get_prefetcher"):
-            valid_pf = self.get_prefetcher(valid_dl)
+        valid_ds = self.get_dataset("valid")
+        valid_dl = self.get_dataloader(valid_ds, "valid")
+        valid_pf = self.get_prefetcher(valid_dl)
 
         try:
             with valid_pf as valid_pf_ctx:
@@ -559,7 +574,7 @@ class TrainMixin(
         Thread(target=self.log_state, daemon=True).start()
 
         key, model_key = jax.random.split(key)
-        model, optimizer, opt_state, state = self.load_initial_state(model_key)
+        model, optimizer, opt_state, state = self.load_initial_state(model_key, load_optimizer=True)
         state = self.on_training_start(state)
 
         def on_exit() -> None:
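
The load_initial_state overloads let callers decide whether the optimizer and its state are constructed or restored. A usage sketch, where MyTask and config stand in for a concrete TrainMixin subclass and its config (both hypothetical names):

    import jax

    task = MyTask(config)  # hypothetical TrainMixin subclass
    key = jax.random.PRNGKey(task.config.random_seed)

    # Default path (load_optimizer=False): only the model and training state.
    model, state = task.load_initial_state(key)

    # Training path, as run_training now calls it: also builds the optimizer
    # and its initial state (or restores them from a checkpoint).
    model, optimizer, opt_state, state = task.load_initial_state(key, load_optimizer=True)
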
xax/task/script.py CHANGED
@@ -17,8 +17,6 @@ from xax.task.mixins import (
     ProcessMixin,
     RunnableConfig,
     RunnableMixin,
-    StepContextConfig,
-    StepContextMixin,
 )
 
 
@@ -28,7 +26,6 @@ class ScriptConfig(
     GPUStatsConfig,
     ProcessConfig,
     LoggerConfig,
-    StepContextConfig,
     ArtifactsConfig,
     RunnableConfig,
     BaseConfig,
@@ -44,7 +41,6 @@ class Script(
     GPUStatsMixin[ConfigT],
     ProcessMixin[ConfigT],
     LoggerMixin[ConfigT],
-    StepContextMixin[ConfigT],
     ArtifactsMixin[ConfigT],
     RunnableMixin[ConfigT],
     BaseTask[ConfigT],
xax/utils/debugging.py ADDED
@@ -0,0 +1,49 @@
+"""Defines some useful Jax debugging utilities."""
+
+from collections import deque
+from collections.abc import Iterable, Mapping
+from typing import Any, Callable, Deque
+
+from jaxtyping import Array
+
+
+def get_named_leaves(
+    obj: Any,  # noqa: ANN401
+    is_leaf: Callable[[Any], bool] = lambda x: isinstance(x, Array),  # noqa: ANN401
+    max_depth: int = 100,
+) -> list[tuple[str, Any]]:  # noqa: ANN401
+    ret: list[tuple[str, Any]] = []
+    q: Deque[tuple[int, str, Any]] = deque()  # noqa: ANN401
+    q.append((0, "", obj))
+
+    while q:
+        depth, name, node = q.popleft()
+
+        if depth > max_depth:
+            continue
+
+        if hasattr(node, "__dict__") and isinstance(node.__dict__, Mapping):
+            for cname, cnode in node.__dict__.items():
+                gname = f"{name}.{cname}" if name else cname
+                if is_leaf(cnode):
+                    ret.append((gname, cnode))
+                else:
+                    q.append((depth + 1, gname, cnode))
+
+        elif isinstance(node, Mapping):
+            for cname, cnode in node.items():
+                gname = f"{name}.{cname}" if name else cname
+                if is_leaf(cnode):
+                    ret.append((gname, cnode))
+                else:
+                    q.append((depth + 1, gname, cnode))
+
+        elif isinstance(node, Iterable):
+            for i, cnode in enumerate(node):
+                gname = f"{name}.{i}" if name else str(i)
+                if is_leaf(cnode):
+                    ret.append((gname, cnode))
+                else:
+                    q.append((depth + 1, gname, cnode))
+
+    return ret
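
get_named_leaves does a breadth-first walk over __dict__ attributes, mappings, and iterables, returning dotted paths to each array leaf. A small self-contained sketch, assuming the function above is in scope; the Params class is illustrative:

    import jax.numpy as jnp

    class Params:
        def __init__(self) -> None:
            self.weight = jnp.ones((3, 4))
            self.child = {"bias": jnp.zeros(4)}

    for name, leaf in get_named_leaves(Params()):
        print(name, leaf.shape)
    # Expected output, in breadth-first order:
    #   weight (3, 4)
    #   child.bias (4,)
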
xax/utils/experiments.py CHANGED
@@ -23,7 +23,8 @@ import urllib.request
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Iterator, TypeVar, cast
+from types import TracebackType
+from typing import Any, Iterator, Self, TypeVar, cast
 from urllib.parse import urlparse
 
 import git
@@ -116,19 +117,19 @@ class StateTimer:
         logs: dict[str, dict[str, int | float]] = {}
 
         # Logs step statistics.
-        logs[" steps"] = {
+        logs[" steps"] = {
             "total": self.step_timer.steps,
             "per-second": self.step_timer.steps_per_second,
         }
 
         # Logs sample statistics.
-        logs[" samples"] = {
+        logs[" samples"] = {
             "total": self.sample_timer.steps,
             "per-second": self.sample_timer.steps_per_second,
         }
 
         # Logs full iteration statistics.
-        logs["🔧 dt"] = {
+        logs[" dt"] = {
             "iter": self.iter_timer.iter_seconds,
         }
 
@@ -147,6 +148,24 @@ class IntervalTicker:
         return False
 
 
+class ContextTimer:
+    def __init__(self) -> None:
+        self.start_time = 0.0
+        self.elapsed_time = 0.0
+
+    def __enter__(self) -> Self:
+        self.start_time = time.time()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> None:
+        self.elapsed_time = time.time() - self.start_time
+
+
 def abs_path(path: str) -> str:
     return str(Path(path).resolve())
 
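
The new ContextTimer records the wall-clock duration of a with-block and keeps it on the instance after exit. A minimal sketch, assuming the class above is in scope:

    import time

    timer = ContextTimer()
    with timer:
        time.sleep(0.25)  # stand-in for real work
    print(f"block took {timer.elapsed_time:.2f}s")  # ~0.25s
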
xax/utils/jaxpr.py ADDED
@@ -0,0 +1,77 @@
+"""Visualize JAXPR."""
+
+from pathlib import Path
+
+import jax
+import jax.core
+
+
+def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) -> None:
+    """Save the JAXPR to a DOT file.
+
+    Example usage:
+
+        grad_fn_jaxpr = jax.make_jaxpr(loss_fn)(variables)
+        save_jaxpr_dot(grad_fn_jaxpr, "grad_fn_jaxpr.dot")
+
+    Then, you can visualize the JAXPR using Graphviz:
+
+        dot -Tpng grad_fn_jaxpr.dot > grad_fn_jaxpr.png
+
+    Args:
+        closed_jaxpr: The closed JAXPR to save.
+        filename: The filename to save the JAXPR to.
+    """
+    if hasattr(closed_jaxpr, "jaxpr"):
+        jaxpr = closed_jaxpr.jaxpr
+    else:
+        jaxpr = closed_jaxpr
+
+    with open(filename, "w") as f:
+        f.write("digraph Jaxpr {\n")
+
+        var_names: dict[jax.core.Var, str] = {}
+        var_count = 0
+
+        def get_var_name(var: jax.core.Var) -> str:
+            """Get a unique name for a variable."""
+            nonlocal var_names, var_count
+
+            # Handle Literal objects specially since they're not hashable
+            if isinstance(var, jax.core.Literal):
+                # Create a name based on the literal value
+                name = f"lit_{var.val}"
+                return name
+
+            # For other variables
+            if var not in var_names:
+                name = f"var_{var_count}"
+                var_names[var] = name
+                var_count += 1
+            return var_names[var]
+
+        for var in jaxpr.invars:
+            node_name = get_var_name(var)
+            f.write(f' {node_name} [label="{node_name}\\n(input)"];\n')
+
+        eq_count = 0
+        for eq in jaxpr.eqns:
+            eq_node = f"eq{eq_count}"
+            label = f"{eq.primitive.name}"
+            f.write(f' {eq_node} [shape=box, label="{label}"];\n')
+
+            for invar in eq.invars:
+                var_name = get_var_name(invar)
+                f.write(f" {var_name} -> {eq_node};\n")
+
+            for outvar in eq.outvars:
+                var_name = get_var_name(outvar)
+                f.write(f" {eq_node} -> {var_name};\n")
+
+            eq_count += 1
+
+        for var in jaxpr.outvars:
+            node_name = get_var_name(var)
+            f.write(f' {node_name} [peripheries=2, label="{node_name}\\n(output)"];\n')
+
+        f.write("}\n")
xax/utils/logging.py CHANGED
@@ -140,7 +140,13 @@ class ColoredFormatter(logging.Formatter):
     return logging.Formatter.format(self, record)
 
 
-def configure_logging(prefix: str | None = None, *, rank: int | None = None, world_size: int | None = None) -> None:
+def configure_logging(
+    prefix: str | None = None,
+    *,
+    rank: int | None = None,
+    world_size: int | None = None,
+    debug: bool | None = None,
+) -> None:
     """Instantiates logging.
 
     This captures logs and reroutes them to the Toasts module, which is
@@ -151,6 +157,7 @@ def configure_logging(prefix: str | None = None, *, rank: int | None = None, wor
         prefix: An optional prefix to add to the logger
         rank: The current rank, or None if not using multiprocessing
         world_size: The total world size, or None if not using multiprocessing
+        debug: Whether to enable debug logging
     """
     if rank is not None or world_size is not None:
         assert rank is not None and world_size is not None
@@ -168,7 +175,10 @@ def configure_logging(prefix: str | None = None, *, rank: int | None = None, wor
     stream_handler.addFilter(filter)
     root_logger.addHandler(stream_handler)
 
-    root_logger.setLevel(logging._nameToLevel[config.log_level])
+    if debug is None:
+        root_logger.setLevel(logging._nameToLevel[config.log_level])
+    else:
+        root_logger.setLevel(logging.DEBUG if debug else logging.INFO)
 
     # Avoid junk logs from other libraries.
     if config.hide_third_party_logs:
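
The new debug flag overrides the config-driven level: None keeps the old behavior, True forces DEBUG, and False forces INFO. A usage sketch:

    from xax.utils.logging import configure_logging

    configure_logging()             # level from config.log_level, as before
    configure_logging(debug=True)   # root logger at DEBUG
    configure_logging(debug=False)  # root logger at INFO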