xax 0.0.6-py3-none-any.whl → 0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/task/mixins/train.py CHANGED
@@ -24,6 +24,7 @@ from typing import (
     TypeVar,
     cast,
     get_args,
+    overload,
 )
 
 import equinox as eqx
@@ -35,6 +36,7 @@ from omegaconf import DictConfig
 
 from xax.core.conf import field
 from xax.core.state import Phase, State
+from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
 from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
@@ -115,7 +117,7 @@ class ValidStepTimer:
         if self.last_valid_time is None or self.last_valid_step is None:
             self.last_valid_time = state.elapsed_time_s
             self.last_valid_step = state.num_steps
-            return True
+            return False
 
         # Step-based validation.
         valid_every_n_steps = self.valid_every_n_steps
@@ -183,6 +185,9 @@ class TrainMixin(
     def __init__(self, config: Config) -> None:
         super().__init__(config)
 
+        # Sets the random seed whenever we instantiate a new train mixin.
+        set_random_seed(self.config.random_seed)
+
         # Timer for validation steps.
         self.valid_step_timer = ValidStepTimer(
             valid_every_n_steps=config.valid_every_n_steps,
@@ -279,31 +284,53 @@ class TrainMixin(
     def get_initial_opt_state(self, model: PyTree, optimizer: optax.GradientTransformation) -> optax.OptState:
         return optimizer.init(eqx.filter(model, eqx.is_array))
 
+    @overload
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: Literal[False] = False,
+    ) -> tuple[PyTree, State]: ...
+
+    @overload
     def load_initial_state(
         self,
         key: PRNGKeyArray,
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
+        load_optimizer: Literal[True],
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]: ...
+
+    def load_initial_state(
+        self,
+        key: PRNGKeyArray,
+        load_optimizer: bool = False,
+    ) -> tuple[PyTree, State] | tuple[PyTree, optax.GradientTransformation, optax.OptState, State]:
         init_ckpt_path = self.get_init_ckpt_path()
 
         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-            with self.step_context("load_checkpoint"):
+            if load_optimizer:
                 model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
                 config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
                 if config_diff:
                     logger.warning("Loaded config differs from current config:\n%s", config_diff)
                 return model, optimizer, opt_state, state
 
-        with self.step_context("get_model"):
-            model = self.get_model(key)
+            else:
+                model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
+                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
+                if config_diff:
+                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+                return model, state
+
+        model = self.get_model(key)
+        state = State.init_state()
 
-        with self.step_context("get_optimizer"):
-            optimizer = self.get_optimizer()
+        if not load_optimizer:
+            return model, state
 
-        with self.step_context("get_initial_opt_state"):
-            opt_state = self.get_initial_opt_state(model, optimizer)
+        optimizer = self.get_optimizer()
+        opt_state = self.get_initial_opt_state(model, optimizer)
 
-        return model, optimizer, opt_state, State.init_state()
+        return model, optimizer, opt_state, state
 
     @eqx.filter_jit
     def get_output(self, model: PyTree, batch: Batch) -> Output:
@@ -424,6 +451,7 @@ class TrainMixin(
     def log_state(self) -> None:
         logger.log(LOG_STATUS, self.task_path)
         logger.log(LOG_STATUS, self.task_name)
+        logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
         self.logger.log_file("git_state.txt", get_git_state(self))
         self.logger.log_file("training_code.txt", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
@@ -456,7 +484,8 @@ class TrainMixin(
         while not self.is_training_over(state):
             if self.valid_step_timer.is_valid_step(state):
                 valid_batch = next(valid_pf)
-                model, loss, output = self.val_step(model, valid_batch)
+                with self.step_context("model_step"):
+                    model, loss, output = self.val_step(model, valid_batch)
 
                 # Perform logging.
                 with self.step_context("write_logs"):
@@ -464,22 +493,19 @@ class TrainMixin(
                     self.log_step(model, valid_batch, output, loss, state)
                     state.num_valid_samples += 1
 
-            with self.step_context("on_step_start"):
-                state = self.on_step_start(state)
+            state = self.on_step_start(state)
 
-            with self.step_context("update_state"):
+            with self.step_context("model_step"):
                 train_batch = next(train_pf)
                 model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)
 
-            # Perform logging.
             with self.step_context("write_logs"):
                 state.phase = "train"
                 self.log_step(model, train_batch, output, loss, state)
                 state.num_steps += 1
                 state.num_samples += self.get_size_of_batch(train_batch) or 0
 
-            with self.step_context("on_step_end"):
-                state = self.on_step_end(state)
+            state = self.on_step_end(state)
 
             if self.should_checkpoint(state):
                 self.save_checkpoint(model, optimizer, opt_state, state)
@@ -496,14 +522,9 @@ class TrainMixin(
         except NotImplementedError:
             pass
 
-        with self.step_context("get_dataset"):
-            train_ds = self.get_dataset("train")
-
-        with self.step_context("get_dataloader"):
-            train_dl = self.get_dataloader(train_ds, "train")
-
-        with self.step_context("get_prefetcher"):
-            train_pf = self.get_prefetcher(train_dl)
+        train_ds = self.get_dataset("train")
+        train_dl = self.get_dataloader(train_ds, "train")
+        train_pf = self.get_prefetcher(train_dl)
 
         try:
             with train_pf as train_pf_ctx:
@@ -520,14 +541,9 @@ class TrainMixin(
         except NotImplementedError:
             pass
 
-        with self.step_context("get_dataset"):
-            valid_ds = self.get_dataset("valid")
-
-        with self.step_context("get_dataloader"):
-            valid_dl = self.get_dataloader(valid_ds, "valid")
-
-        with self.step_context("get_prefetcher"):
-            valid_pf = self.get_prefetcher(valid_dl)
+        valid_ds = self.get_dataset("valid")
+        valid_dl = self.get_dataloader(valid_ds, "valid")
+        valid_pf = self.get_prefetcher(valid_dl)
 
         try:
             with valid_pf as valid_pf_ctx:
@@ -559,7 +575,7 @@ class TrainMixin(
             Thread(target=self.log_state, daemon=True).start()
 
            key, model_key = jax.random.split(key)
-            model, optimizer, opt_state, state = self.load_initial_state(model_key)
+            model, optimizer, opt_state, state = self.load_initial_state(model_key, load_optimizer=True)
            state = self.on_training_start(state)
 
            def on_exit() -> None:
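
The overloads added to `load_initial_state` change the call-site contract: the default `load_optimizer=False` returns only `(model, state)`, while `load_optimizer=True` also constructs or restores the optimizer and its state, which is how the training entry point now calls it. A minimal sketch of the two call patterns, assuming a hypothetical `MyTask` that mixes in `TrainMixin` with a config `cfg`:

    key = jax.random.PRNGKey(0)
    task = MyTask(cfg)  # hypothetical TrainMixin subclass and config

    # Evaluation-style loading: no optimizer or optimizer state is created.
    model, state = task.load_initial_state(key)

    # Training-style loading, matching the updated call in this diff.
    model, optimizer, opt_state, state = task.load_initial_state(key, load_optimizer=True)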
xax/task/script.py CHANGED
@@ -17,8 +17,6 @@ from xax.task.mixins import (
     ProcessMixin,
     RunnableConfig,
     RunnableMixin,
-    StepContextConfig,
-    StepContextMixin,
 )
 
 
@@ -28,7 +26,6 @@ class ScriptConfig(
     GPUStatsConfig,
     ProcessConfig,
     LoggerConfig,
-    StepContextConfig,
     ArtifactsConfig,
     RunnableConfig,
     BaseConfig,
@@ -44,7 +41,6 @@ class Script(
     GPUStatsMixin[ConfigT],
     ProcessMixin[ConfigT],
     LoggerMixin[ConfigT],
-    StepContextMixin[ConfigT],
     ArtifactsMixin[ConfigT],
     RunnableMixin[ConfigT],
     BaseTask[ConfigT],
xax/utils/debugging.py ADDED
@@ -0,0 +1,49 @@
+"""Defines some useful Jax debugging utilities."""
+
+from collections import deque
+from collections.abc import Iterable, Mapping
+from typing import Any, Callable, Deque
+
+from jaxtyping import Array
+
+
+def get_named_leaves(
+    obj: Any,  # noqa: ANN401
+    is_leaf: Callable[[Any], bool] = lambda x: isinstance(x, Array),  # noqa: ANN401
+    max_depth: int = 100,
+) -> list[tuple[str, Any]]:  # noqa: ANN401
+    ret: list[tuple[str, Any]] = []
+    q: Deque[tuple[int, str, Any]] = deque()  # noqa: ANN401
+    q.append((0, "", obj))
+
+    while q:
+        depth, name, node = q.popleft()
+
+        if depth > max_depth:
+            continue
+
+        if hasattr(node, "__dict__") and isinstance(node.__dict__, Mapping):
+            for cname, cnode in node.__dict__.items():
+                gname = f"{name}.{cname}" if name else cname
+                if is_leaf(cnode):
+                    ret.append((gname, cnode))
+                else:
+                    q.append((depth + 1, gname, cnode))
+
+        elif isinstance(node, Mapping):
+            for cname, cnode in node.items():
+                gname = f"{name}.{cname}" if name else cname
+                if is_leaf(cnode):
+                    ret.append((gname, cnode))
+                else:
+                    q.append((depth + 1, gname, cnode))
+
+        elif isinstance(node, Iterable):
+            for i, cnode in enumerate(node):
+                gname = f"{name}.{i}" if name else str(i)
+                if is_leaf(cnode):
+                    ret.append((gname, cnode))
+                else:
+                    q.append((depth + 1, gname, cnode))
+
+    return ret
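
`get_named_leaves` does a breadth-first walk over an object's attributes, mappings, and iterables, and returns a dotted-path name for each array leaf it finds. A small illustrative sketch, assuming an Equinox MLP as the object being inspected:

    import equinox as eqx
    import jax

    from xax.utils.debugging import get_named_leaves

    model = eqx.nn.MLP(in_size=3, out_size=2, width_size=8, depth=2, key=jax.random.PRNGKey(0))
    for name, leaf in get_named_leaves(model):
        # Prints dotted paths such as "layers.0.weight" alongside each array's shape.
        print(name, leaf.shape)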
xax/utils/experiments.py CHANGED
@@ -23,7 +23,8 @@ import urllib.request
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Iterator, TypeVar, cast
+from types import TracebackType
+from typing import Any, Iterator, Self, TypeVar, cast
 from urllib.parse import urlparse
 
 import git
@@ -116,19 +117,19 @@ class StateTimer:
         logs: dict[str, dict[str, int | float]] = {}
 
         # Logs step statistics.
-        logs[" steps"] = {
+        logs[" steps"] = {
             "total": self.step_timer.steps,
             "per-second": self.step_timer.steps_per_second,
         }
 
         # Logs sample statistics.
-        logs[" samples"] = {
+        logs[" samples"] = {
             "total": self.sample_timer.steps,
             "per-second": self.sample_timer.steps_per_second,
         }
 
         # Logs full iteration statistics.
-        logs["🔧 dt"] = {
+        logs[" dt"] = {
             "iter": self.iter_timer.iter_seconds,
         }
 
@@ -147,6 +148,24 @@ class IntervalTicker:
         return False
 
 
+class ContextTimer:
+    def __init__(self) -> None:
+        self.start_time = 0.0
+        self.elapsed_time = 0.0
+
+    def __enter__(self) -> Self:
+        self.start_time = time.time()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> None:
+        self.elapsed_time = time.time() - self.start_time
+
+
 def abs_path(path: str) -> str:
     return str(Path(path).resolve())
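
The new `ContextTimer` is a small context manager that records the wall-clock time between entry and exit in `elapsed_time`. A usage sketch, with `expensive_step` standing in for whatever is being timed:

    from xax.utils.experiments import ContextTimer

    with ContextTimer() as timer:
        expensive_step()  # placeholder for the code being measured

    print(f"step took {timer.elapsed_time:.3f} seconds")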
 
xax/utils/jax.py CHANGED
@@ -1,14 +1,140 @@
 """Defines some utility functions for interfacing with Jax."""
 
+import inspect
+import logging
+import os
+import time
+from functools import wraps
+from typing import Any, Callable, Iterable, ParamSpec, Sequence, TypeVar, cast
+
+import jax
 import jax.numpy as jnp
 import numpy as np
+from jax._src import sharding_impls
+from jax._src.lib import xla_client as xc
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_COMPILE_TIMEOUT = 1.0
 
 Number = int | float | np.ndarray | jnp.ndarray
 
 
+P = ParamSpec("P")  # For function parameters
+R = TypeVar("R")  # For function return type
+
+
 def as_float(value: int | float | np.ndarray | jnp.ndarray) -> float:
     if isinstance(value, (int, float)):
         return float(value)
     if isinstance(value, (np.ndarray, jnp.ndarray)):
         return float(value.item())
     raise TypeError(f"Unexpected type: {type(value)}")
+
+
+def get_hash(obj: object) -> int:
+    """Get a hash of an object.
+
+    If the object is hashable, use the hash. Otherwise, use the id.
+    """
+    if hasattr(obj, "__hash__"):
+        return hash(obj)
+    return id(obj)
+
+
+def jit(
+    in_shardings: Any = sharding_impls.UNSPECIFIED,  # noqa: ANN401
+    out_shardings: Any = sharding_impls.UNSPECIFIED,  # noqa: ANN401
+    static_argnums: int | Sequence[int] | None = None,
+    static_argnames: str | Iterable[str] | None = None,
+    donate_argnums: int | Sequence[int] | None = None,
+    donate_argnames: str | Iterable[str] | None = None,
+    keep_unused: bool = False,
+    device: xc.Device | None = None,
+    backend: str | None = None,
+    inline: bool = False,
+    abstracted_axes: Any | None = None,  # noqa: ANN401
+    compiler_options: dict[str, Any] | None = None,
+) -> Callable[[Callable[P, R]], Callable[P, R]]:
+    """Wrapper function that provides utility improvements over Jax's JIT.
+
+    Specifically, this function works on class methods, is toggleable, and
+    detects recompilations by matching hash values.
+
+    This is meant to be used as a decorator factory, and the decorated function
+    calls `wrapped`.
+    """
+
+    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
+        class JitState:
+            compilation_count = 0
+            last_arg_dict: dict[str, int] | None = None
+
+        sig = inspect.signature(fn)
+        param_names = list(sig.parameters.keys())
+
+        jitted_fn = jax.jit(
+            fn,
+            in_shardings=in_shardings,
+            out_shardings=out_shardings,
+            static_argnums=static_argnums,
+            static_argnames=static_argnames,
+            donate_argnums=donate_argnums,
+            donate_argnames=donate_argnames,
+            keep_unused=keep_unused,
+            device=device,
+            backend=backend,
+            inline=inline,
+            abstracted_axes=abstracted_axes,
+            compiler_options=compiler_options,
+        )
+
+        @wraps(fn)
+        def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
+            if os.environ.get("DEBUG", "0") == "1":  # skipping during debug
+                return fn(*args, **kwargs)
+
+            do_profile = os.environ.get("JIT_PROFILE", "0") == "1"
+
+            if do_profile:
+                class_name = (args[0].__class__.__name__) + "." if fn.__name__ == "__call__" else ""
+                logger.info(
+                    "Currently running %s (count: %s)",
+                    f"{class_name}{fn.__name__}",
+                    JitState.compilation_count,
+                )
+
+            start_time = time.time()
+            res = jitted_fn(*args, **kwargs)
+            end_time = time.time()
+            runtime = end_time - start_time
+
+            # if this is true, if runtime is higher than COMPILE_TIMEOUT, we recompile
+            # TODO: we should probably reimplement the lower-level jitting logic to avoid this
+            if do_profile:
+                arg_dict = {}
+                for i, arg in enumerate(args):
+                    if i < len(param_names):
+                        arg_dict[param_names[i]] = get_hash(arg)
+                for k, v in kwargs.items():
+                    arg_dict[k] = get_hash(v)
+
+                logger.info("Hashing took %s seconds", runtime)
+                JitState.compilation_count += 1
+
+                if JitState.last_arg_dict is not None:
+                    all_keys = set(arg_dict.keys()) | set(JitState.last_arg_dict.keys())
+                    for k in all_keys:
+                        prev = JitState.last_arg_dict.get(k, "N/A")
+                        curr = arg_dict.get(k, "N/A")
+
+                        if prev != curr:
+                            logger.info("- Arg '%s' hash changed: %s -> %s", k, prev, curr)
+
+                JitState.last_arg_dict = arg_dict
+
+            return cast(R, res)
+
+        return wrapped
+
+    return decorator
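
The new `jit` helper is a decorator factory that forwards its options to `jax.jit`; per the source above, setting `DEBUG=1` bypasses compilation entirely and `JIT_PROFILE=1` logs argument-hash changes between calls to help spot recompilations. A minimal usage sketch with an illustrative function:

    import jax.numpy as jnp

    from xax.utils.jax import jit

    @jit()
    def scaled_sum(x: jnp.ndarray, scale: float) -> jnp.ndarray:
        # Compiled on first call, like any jax.jit-wrapped function.
        return scale * jnp.sum(x)

    print(scaled_sum(jnp.arange(8.0), 2.0))  # 56.0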
xax/utils/jaxpr.py ADDED
@@ -0,0 +1,77 @@
+"""Visualize JAXPR."""
+
+from pathlib import Path
+
+import jax
+import jax.core
+
+
+def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) -> None:
+    """Save the JAXPR to a DOT file.
+
+    Example usage:
+
+        grad_fn_jaxpr = jax.make_jaxpr(loss_fn)(variables)
+        save_jaxpr_dot(grad_fn_jaxpr, "grad_fn_jaxpr.dot")
+
+    Then, you can visualize the JAXPR using Graphviz:
+
+        dot -Tpng grad_fn_jaxpr.dot > grad_fn_jaxpr.png
+
+    Args:
+        closed_jaxpr: The closed JAXPR to save.
+        filename: The filename to save the JAXPR to.
+    """
+    if hasattr(closed_jaxpr, "jaxpr"):
+        jaxpr = closed_jaxpr.jaxpr
+    else:
+        jaxpr = closed_jaxpr
+
+    with open(filename, "w") as f:
+        f.write("digraph Jaxpr {\n")
+
+        var_names: dict[jax.core.Var, str] = {}
+        var_count = 0
+
+        def get_var_name(var: jax.core.Var) -> str:
+            """Get a unique name for a variable."""
+            nonlocal var_names, var_count
+
+            # Handle Literal objects specially since they're not hashable
+            if isinstance(var, jax.core.Literal):
+                # Create a name based on the literal value
+                name = f"lit_{var.val}"
+                return name
+
+            # For other variables
+            if var not in var_names:
+                name = f"var_{var_count}"
+                var_names[var] = name
+                var_count += 1
+            return var_names[var]
+
+        for var in jaxpr.invars:
+            node_name = get_var_name(var)
+            f.write(f'  {node_name} [label="{node_name}\\n(input)"];\n')
+
+        eq_count = 0
+        for eq in jaxpr.eqns:
+            eq_node = f"eq{eq_count}"
+            label = f"{eq.primitive.name}"
+            f.write(f'  {eq_node} [shape=box, label="{label}"];\n')
+
+            for invar in eq.invars:
+                var_name = get_var_name(invar)
+                f.write(f"  {var_name} -> {eq_node};\n")
+
+            for outvar in eq.outvars:
+                var_name = get_var_name(outvar)
+                f.write(f"  {eq_node} -> {var_name};\n")
+
+            eq_count += 1
+
+        for var in jaxpr.outvars:
+            node_name = get_var_name(var)
+            f.write(f'  {node_name} [peripheries=2, label="{node_name}\\n(output)"];\n')
+
+        f.write("}\n")
xax/utils/profile.py ADDED
@@ -0,0 +1,61 @@
+"""Profiling utilities."""
+
+import logging
+import os
+import time
+from functools import wraps
+from typing import Callable, ParamSpec, TypeVar
+
+logger = logging.getLogger(__name__)
+
+P = ParamSpec("P")  # For function parameters
+R = TypeVar("R")  # For function return type
+
+
+def profile(fn: Callable[P, R]) -> Callable[P, R]:
+    """Profiling decorator that tracks function call count and execution time.
+
+    Activated when the PROFILE environment variable is set to "1".
+
+    Returns:
+        A decorated function with profiling capabilities.
+    """
+
+    class ProfileState:
+        call_count = 0
+        total_time = 0.0
+
+    @wraps(fn)
+    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
+        if os.environ.get("PROFILE", "0") != "1":
+            return fn(*args, **kwargs)
+
+        start_time = time.time()
+        res = fn(*args, **kwargs)
+        end_time = time.time()
+        runtime = end_time - start_time
+
+        ProfileState.call_count += 1
+        ProfileState.total_time += runtime
+
+        # Handle class methods by showing class name
+        if fn.__name__ == "__call__" or (args and hasattr(args[0], "__class__")):
+            try:
+                class_name = args[0].__class__.__name__ + "."
+            except (IndexError, AttributeError):
+                class_name = ""
+        else:
+            class_name = ""
+
+        logger.info(
+            "%s %s - call #%s, took %s seconds, total: %s seconds",
+            class_name,
+            fn.__name__,
+            ProfileState.call_count,
+            runtime,
+            ProfileState.total_time,
+        )
+
+        return res
+
+    return wrapped
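
The `profile` decorator is a no-op unless the PROFILE environment variable is set to "1", in which case every call is logged through the module logger with its runtime and a running total. A brief sketch with an illustrative function:

    import os

    from xax.utils.profile import profile

    os.environ["PROFILE"] = "1"  # enable per-call logging

    @profile
    def slow_add(a: float, b: float) -> float:
        return a + b

    slow_add(1.0, 2.0)  # logged with call count, runtime, and cumulative time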