xax 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.2"
15
+ __version__ = "0.2.3"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -104,6 +104,7 @@ __all__ = [
104
104
  "stage_environment",
105
105
  "to_markdown_table",
106
106
  "jit",
107
+ "scan",
107
108
  "save_jaxpr_dot",
108
109
  "ColoredFormatter",
109
110
  "configure_logging",
@@ -267,6 +268,7 @@ NAME_MAP: dict[str, str] = {
267
268
  "stage_environment": "utils.experiments",
268
269
  "to_markdown_table": "utils.experiments",
269
270
  "jit": "utils.jax",
271
+ "scan": "utils.jax",
270
272
  "save_jaxpr_dot": "utils.jaxpr",
271
273
  "ColoredFormatter": "utils.logging",
272
274
  "configure_logging": "utils.logging",
@@ -422,7 +424,7 @@ if IMPORT_ALL or TYPE_CHECKING:
422
424
  stage_environment,
423
425
  to_markdown_table,
424
426
  )
425
- from xax.utils.jax import jit
427
+ from xax.utils.jax import jit, scan
426
428
  from xax.utils.jaxpr import save_jaxpr_dot
427
429
  from xax.utils.logging import (
428
430
  LOG_ERROR_SUMMARY,
xax/core/state.py CHANGED
@@ -1,10 +1,12 @@
1
1
  """Defines a dataclass for keeping track of the current training state."""
2
2
 
3
3
  import time
4
- from dataclasses import asdict, dataclass
5
- from typing import Any, Literal, NotRequired, TypedDict, Unpack, cast
4
+ from dataclasses import dataclass
5
+ from typing import Literal, NotRequired, TypedDict, Unpack, cast
6
6
 
7
7
  import jax
8
+ import jax.numpy as jnp
9
+ from jaxtyping import Array
8
10
  from omegaconf import MISSING
9
11
 
10
12
  from xax.core.conf import field
@@ -21,67 +23,89 @@ def _int_to_phase(i: int) -> Phase:
21
23
 
22
24
 
23
25
  class StateDict(TypedDict, total=False):
24
- num_steps: NotRequired[int]
25
- num_samples: NotRequired[int]
26
- num_valid_steps: NotRequired[int]
27
- num_valid_samples: NotRequired[int]
28
- start_time_s: NotRequired[float]
29
- elapsed_time_s: NotRequired[float]
26
+ num_steps: NotRequired[int | Array]
27
+ num_samples: NotRequired[int | Array]
28
+ num_valid_steps: NotRequired[int | Array]
29
+ num_valid_samples: NotRequired[int | Array]
30
+ start_time_s: NotRequired[float | Array]
31
+ elapsed_time_s: NotRequired[float | Array]
30
32
  phase: NotRequired[Phase]
33
+ _phase: NotRequired[int | Array]
31
34
 
32
35
 
33
36
  @jax.tree_util.register_dataclass
34
37
  @dataclass(frozen=True, kw_only=True)
35
38
  class State:
36
- num_steps: int = field(MISSING, help="Number of steps so far")
37
- num_samples: int = field(MISSING, help="Number of sample so far")
38
- num_valid_steps: int = field(MISSING, help="Number of validation steps so far")
39
- num_valid_samples: int = field(MISSING, help="Number of validation samples so far")
40
- start_time_s: float = field(MISSING, help="Start time of training")
41
- elapsed_time_s: float = field(MISSING, help="Total elapsed time so far")
42
- _phase: int = field(MISSING, help="Current training phase")
39
+ _int32_arr: Array = field(MISSING, help="Internal array for storing int64 values")
40
+ _float32_arr: Array = field(MISSING, help="Internal array for storing floating-point values")
41
+
42
+ @property
43
+ def num_steps(self) -> Array:
44
+ return self._int32_arr[0]
45
+
46
+ @property
47
+ def num_samples(self) -> Array:
48
+ return self._float32_arr[0]
49
+
50
+ @property
51
+ def num_valid_steps(self) -> Array:
52
+ return self._int32_arr[1]
53
+
54
+ @property
55
+ def num_valid_samples(self) -> Array:
56
+ return self._float32_arr[1]
57
+
58
+ @property
59
+ def start_time_s(self) -> Array:
60
+ return self._float32_arr[2]
61
+
62
+ @property
63
+ def elapsed_time_s(self) -> Array:
64
+ return self._float32_arr[3]
43
65
 
44
66
  @property
45
67
  def phase(self) -> Phase:
46
- return _int_to_phase(self._phase)
68
+ return _int_to_phase(self._int32_arr[2].item())
47
69
 
48
70
  @classmethod
49
71
  def init_state(cls) -> "State":
50
72
  return cls(
51
- num_steps=0,
52
- num_samples=0,
53
- num_valid_steps=0,
54
- num_valid_samples=0,
55
- start_time_s=time.time(),
56
- elapsed_time_s=0.0,
57
- _phase=0,
73
+ _int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
74
+ _float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0], dtype=jnp.float32),
58
75
  )
59
76
 
60
77
  @property
61
78
  def training(self) -> bool:
62
79
  return self.phase == "train"
63
80
 
64
- def num_phase_steps(self, phase: Phase) -> int:
65
- match phase:
66
- case "train":
67
- return self.num_steps
68
- case "valid":
69
- return self.num_valid_steps
70
- case _:
71
- raise ValueError(f"Invalid phase: {phase}")
72
-
73
81
  def replace(self, **kwargs: Unpack[StateDict]) -> "State":
74
- extra_kwargs: dict[str, Any] = {} # noqa: ANN401
82
+ int32_arr = self._int32_arr
83
+ float32_arr = self._float32_arr
84
+
85
+ if "num_steps" in kwargs:
86
+ int32_arr = int32_arr.at[0].set(kwargs["num_steps"])
87
+ if "num_valid_steps" in kwargs:
88
+ int32_arr = int32_arr.at[1].set(kwargs["num_valid_steps"])
89
+
75
90
  if "phase" in kwargs:
76
- phase = kwargs.pop("phase")
77
- match phase:
78
- case "train":
79
- extra_kwargs["_phase"] = 0
80
- case "valid":
81
- extra_kwargs["_phase"] = 1
82
- case _:
83
- raise ValueError(f"Invalid phase: {phase}")
84
- return State(**{**asdict(self), **kwargs, **extra_kwargs})
91
+ int32_arr = int32_arr.at[3].set(_phase_to_int(kwargs["phase"]))
92
+ if "_phase" in kwargs:
93
+ int32_arr = int32_arr.at[3].set(kwargs["_phase"])
94
+
95
+ if "num_samples" in kwargs:
96
+ float32_arr = float32_arr.at[0].set(kwargs["num_samples"])
97
+ if "num_valid_samples" in kwargs:
98
+ float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
99
+
100
+ if "start_time_s" in kwargs:
101
+ float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
102
+ if "elapsed_time_s" in kwargs:
103
+ float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
104
+
105
+ return State(
106
+ _int32_arr=int32_arr,
107
+ _float32_arr=float32_arr,
108
+ )
85
109
 
86
110
  def to_dict(self) -> dict[str, int | float | str]:
87
111
  return {
@@ -95,7 +119,30 @@ class State:
95
119
  }
96
120
 
97
121
  @classmethod
98
- def from_dict(cls, d: dict[str, int | float | str]) -> "State":
122
+ def from_dict(cls, **d: Unpack[StateDict]) -> "State":
99
123
  if "phase" in d:
100
124
  d["_phase"] = _phase_to_int(cast(Phase, d.pop("phase")))
101
- return cls(**d) # type: ignore[arg-type]
125
+
126
+ int32_arr = jnp.array(
127
+ [
128
+ d.get("num_steps", 0),
129
+ d.get("num_samples", 0),
130
+ d.get("num_valid_steps", 0),
131
+ d.get("num_valid_samples", 0),
132
+ d.get("_phase", 0),
133
+ ],
134
+ dtype=jnp.int32,
135
+ )
136
+
137
+ float32_arr = jnp.array(
138
+ [
139
+ d.get("start_time_s", time.time()),
140
+ d.get("elapsed_time_s", 0.0),
141
+ ],
142
+ dtype=jnp.float32,
143
+ )
144
+
145
+ return cls(
146
+ _int32_arr=int32_arr,
147
+ _float32_arr=float32_arr,
148
+ )
xax/task/logger.py CHANGED
@@ -521,7 +521,7 @@ class LoggerImpl(ABC):
521
521
  Returns:
522
522
  If the logger should log the current step.
523
523
  """
524
- return self.tickers[state.phase].tick(state.elapsed_time_s)
524
+ return self.tickers[state.phase].tick(state.elapsed_time_s.item())
525
525
 
526
526
 
527
527
  class ToastHandler(logging.Handler):
@@ -90,9 +90,9 @@ class StdoutLogger(LoggerImpl):
90
90
 
91
91
  def write_state_window(self, line: LogLine) -> None:
92
92
  state_info: dict[str, str] = {
93
- "Steps": format_number(line.state.num_steps, 0),
94
- "Samples": format_number(line.state.num_samples, 0),
95
- "Elapsed Time": format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True),
93
+ "Steps": format_number(int(line.state.num_steps.item()), 0),
94
+ "Samples": format_number(int(line.state.num_samples.item()), 0),
95
+ "Elapsed Time": format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s.item()), short=True),
96
96
  }
97
97
 
98
98
  colored_prefix = colored("Phase: ", "grey", bold=True)
@@ -155,14 +155,16 @@ class TensorboardLogger(LoggerImpl):
155
155
  return
156
156
 
157
157
  writer = self.get_writer(line.state.phase)
158
- walltime = line.state.start_time_s + line.state.elapsed_time_s
158
+
159
+ global_step = line.state.num_steps.item()
160
+ walltime = (line.state.start_time_s + line.state.elapsed_time_s).item()
159
161
 
160
162
  for namespace, scalars in line.scalars.items():
161
163
  for scalar_key, scalar_value in scalars.items():
162
164
  writer.add_scalar(
163
165
  f"{namespace}/{scalar_key}",
164
166
  as_float(scalar_value.value),
165
- global_step=line.state.num_steps,
167
+ global_step=global_step,
166
168
  walltime=walltime,
167
169
  )
168
170
 
@@ -172,7 +174,7 @@ class TensorboardLogger(LoggerImpl):
172
174
  f"{namespace}/{distribution_key}",
173
175
  mean=float(distribution_value.mean),
174
176
  std=float(distribution_value.std),
175
- global_step=line.state.num_steps,
177
+ global_step=global_step,
176
178
  walltime=walltime,
177
179
  )
178
180
 
@@ -187,7 +189,7 @@ class TensorboardLogger(LoggerImpl):
187
189
  sum_squares=float(histogram_value.sum_squares),
188
190
  bucket_limits=[float(x) for x in histogram_value.bucket_limits],
189
191
  bucket_counts=[int(x) for x in histogram_value.bucket_counts],
190
- global_step=line.state.num_steps,
192
+ global_step=global_step,
191
193
  walltime=walltime,
192
194
  )
193
195
 
@@ -196,7 +198,7 @@ class TensorboardLogger(LoggerImpl):
196
198
  writer.add_text(
197
199
  f"{namespace}/{string_key}",
198
200
  string_value.value,
199
- global_step=line.state.num_steps,
201
+ global_step=global_step,
200
202
  walltime=walltime,
201
203
  )
202
204
 
@@ -205,7 +207,7 @@ class TensorboardLogger(LoggerImpl):
205
207
  writer.add_image(
206
208
  f"{namespace}/{image_key}",
207
209
  image_value.image,
208
- global_step=line.state.num_steps,
210
+ global_step=global_step,
209
211
  walltime=walltime,
210
212
  )
211
213
 
@@ -215,7 +217,7 @@ class TensorboardLogger(LoggerImpl):
215
217
  f"{namespace}/{video_key}",
216
218
  video_value.frames,
217
219
  fps=video_value.fps,
218
- global_step=line.state.num_steps,
220
+ global_step=global_step,
219
221
  walltime=walltime,
220
222
  )
221
223
 
@@ -227,7 +229,7 @@ class TensorboardLogger(LoggerImpl):
227
229
  faces=mesh_value.faces,
228
230
  colors=mesh_value.colors,
229
231
  config_dict=mesh_value.config_dict,
230
- global_step=line.state.num_steps,
232
+ global_step=global_step,
231
233
  walltime=walltime,
232
234
  )
233
235
 
@@ -4,7 +4,7 @@ import io
4
4
  import json
5
5
  import logging
6
6
  import tarfile
7
- from dataclasses import asdict, dataclass
7
+ from dataclasses import dataclass
8
8
  from pathlib import Path
9
9
  from typing import Generic, Literal, TypeVar, cast, overload
10
10
 
@@ -76,7 +76,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
76
76
  if state.num_steps % self.config.save_every_n_steps == 0:
77
77
  return True
78
78
  if self.config.save_every_n_seconds is not None:
79
- last_time, cur_time = self.__last_ckpt_time, state.elapsed_time_s
79
+ last_time, cur_time = self.__last_ckpt_time, state.elapsed_time_s.item()
80
80
  if cur_time - last_time >= self.config.save_every_n_seconds:
81
81
  self.__last_ckpt_time = cur_time
82
82
  return True
@@ -200,7 +200,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
200
200
  def get_state() -> State:
201
201
  if (state := tar.extractfile("state")) is None:
202
202
  raise ValueError(f"Checkpoint does not contain a state file: {path}")
203
- return State(**json.loads(state.read().decode()))
203
+ return State.from_dict(**json.loads(state.read().decode()))
204
204
 
205
205
  def get_config() -> Config:
206
206
  if (config := tar.extractfile("config")) is None:
@@ -300,7 +300,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
300
300
  tar.addfile(info, io.BytesIO(data))
301
301
 
302
302
  if state is not None:
303
- add_file_bytes("state", json.dumps(asdict(state), indent=2).encode())
303
+ add_file_bytes("state", json.dumps(state.to_dict(), indent=2).encode())
304
304
  add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
305
305
 
306
306
  # Updates the symlink to the new checkpoint
xax/task/mixins/train.py CHANGED
@@ -46,6 +46,7 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
46
46
  from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
47
47
  from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
48
48
  from xax.utils.experiments import (
49
+ ContextTimer,
49
50
  StateTimer,
50
51
  TrainingFinishedError,
51
52
  diff_configs,
@@ -119,27 +120,30 @@ class ValidStepTimer:
119
120
  return True
120
121
 
121
122
  if self.last_valid_time is None or self.last_valid_step is None:
122
- self.last_valid_time = state.elapsed_time_s
123
- self.last_valid_step = state.num_steps
123
+ self.last_valid_time = state.elapsed_time_s.item()
124
+ self.last_valid_step = state.num_steps.item()
124
125
  return False
125
126
 
126
127
  # Step-based validation.
127
128
  valid_every_n_steps = self.valid_every_n_steps
128
129
  if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
129
- self.last_valid_step = state.num_steps
130
+ self.last_valid_step = state.num_steps.item()
130
131
  return True
131
132
 
132
133
  # Time-based validation.
133
134
  valid_every_n_seconds = self.valid_every_n_seconds
134
- if valid_every_n_seconds is not None and state.elapsed_time_s - self.last_valid_time >= valid_every_n_seconds:
135
- self.last_valid_time = state.elapsed_time_s
135
+ if (
136
+ valid_every_n_seconds is not None
137
+ and state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
138
+ ):
139
+ self.last_valid_time = state.elapsed_time_s.item()
136
140
  return True
137
141
 
138
142
  # Time-based validation for first validation step.
139
143
  if self.first_valid_step_flag:
140
144
  valid_first_n_seconds = self.valid_first_n_seconds
141
- if valid_first_n_seconds is not None and state.elapsed_time_s >= valid_first_n_seconds:
142
- self.last_valid_time = state.elapsed_time_s
145
+ if valid_first_n_seconds is not None and state.elapsed_time_s.item() >= valid_first_n_seconds:
146
+ self.last_valid_time = state.elapsed_time_s.item()
143
147
  self.first_valid_step_flag = False
144
148
  return True
145
149
 
@@ -214,10 +218,6 @@ class TrainMixin(
214
218
  def prng_key(self) -> PRNGKeyArray:
215
219
  return jax.random.PRNGKey(self.config.random_seed)
216
220
 
217
- def on_step_end(self, state: State) -> State:
218
- state = super().on_step_end(state)
219
- return state.replace(elapsed_time_s=time.time() - state.start_time_s)
220
-
221
221
  def log_train_step(
222
222
  self,
223
223
  model: PyTree,
@@ -548,7 +548,7 @@ class TrainMixin(
548
548
  "loss": loss,
549
549
  }
550
550
 
551
- @xax_jit(static_argnames=["self", "model_static"])
551
+ @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
552
552
  def get_output_and_loss(
553
553
  self,
554
554
  model_arr: PyTree,
@@ -572,12 +572,12 @@ class TrainMixin(
572
572
  state: State,
573
573
  ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
574
574
  grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
575
- grad_fn = xax_jit(static_argnums=[1])(grad_fn)
575
+ grad_fn = xax_jit(static_argnums=[1], jit_level=3)(grad_fn)
576
576
  grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
577
577
  model_arr, opt_state, grad_metrics = self.apply_gradients_with_clipping(model_arr, grads, optimizer, opt_state)
578
578
  return model_arr, opt_state, output, metrics | grad_metrics
579
579
 
580
- @xax_jit(static_argnames=["self", "optimizer"])
580
+ @xax_jit(static_argnames=["self", "optimizer"], jit_level=3)
581
581
  def apply_gradients_with_clipping(
582
582
  self,
583
583
  model_arr: PyTree,
@@ -641,8 +641,8 @@ class TrainMixin(
641
641
  def maybe_log_termination_time(self, remaining_percent: float, state: State) -> None:
642
642
  if self._last_printed_remaining_time + PRINT_FINISH_TIME_EVERY_N_SECONDS > state.elapsed_time_s:
643
643
  return
644
- self._last_printed_remaining_time = state.elapsed_time_s
645
- remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
644
+ self._last_printed_remaining_time = state.elapsed_time_s.item()
645
+ remaining_seconds = remaining_percent * state.elapsed_time_s.item() / (1 - remaining_percent)
646
646
  termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
647
647
  logger.log(LOG_PING, "Estimated finish time: %s", termination_time)
648
648
 
@@ -663,11 +663,11 @@ class TrainMixin(
663
663
  def get_step(self, state: State) -> int:
664
664
  match self._step_kind:
665
665
  case "step":
666
- return state.num_steps
666
+ return int(state.num_steps.item())
667
667
  case "sample":
668
- return state.num_samples
668
+ return int(state.num_samples.item())
669
669
  case "second":
670
- return int(state.elapsed_time_s)
670
+ return int(state.elapsed_time_s.item())
671
671
  case _:
672
672
  raise ValueError(f"Invalid step kind {self._step_kind}")
673
673
 
@@ -683,7 +683,7 @@ class TrainMixin(
683
683
  def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
684
684
  return eqx.is_inexact_array(item)
685
685
 
686
- @xax_jit(static_argnames=["self", "model_static", "optimizer"])
686
+ @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
687
687
  def train_step(
688
688
  self,
689
689
  model_arr: PyTree,
@@ -696,7 +696,7 @@ class TrainMixin(
696
696
  model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
697
697
  return model_arr, opt_state, output, FrozenDict(metrics)
698
698
 
699
- @xax_jit(static_argnames=["self", "model_static"])
699
+ @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
700
700
  def val_step(
701
701
  self,
702
702
  model_arr: PyTree,
@@ -732,22 +732,25 @@ class TrainMixin(
732
732
 
733
733
  state = self.on_step_start(state)
734
734
  train_batch = next(train_pf)
735
+
736
+ with ContextTimer() as timer:
737
+ model_arr, opt_state, output, metrics = self.train_step(
738
+ model_arr=model_arr,
739
+ model_static=model_static,
740
+ optimizer=optimizer,
741
+ opt_state=opt_state,
742
+ batch=train_batch,
743
+ state=state,
744
+ )
745
+ self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
746
+
735
747
  state = state.replace(
736
748
  phase="train",
737
749
  num_steps=state.num_steps + 1,
738
750
  num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
751
+ elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
739
752
  )
740
753
 
741
- model_arr, opt_state, output, metrics = self.train_step(
742
- model_arr=model_arr,
743
- model_static=model_static,
744
- optimizer=optimizer,
745
- opt_state=opt_state,
746
- batch=train_batch,
747
- state=state,
748
- )
749
- self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
750
-
751
754
  state = self.on_step_end(state)
752
755
 
753
756
  if self.should_checkpoint(state):
@@ -843,10 +846,8 @@ class TrainMixin(
843
846
 
844
847
  except TrainingFinishedError:
845
848
  if is_master():
846
- show_info(
847
- f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
848
- important=True,
849
- )
849
+ num_steps, num_samples = int(state.num_steps), int(state.num_samples)
850
+ show_info(f"Finished training after {num_steps} steps, {num_samples} samples", important=True)
850
851
  self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
851
852
 
852
853
  except (KeyboardInterrupt, bdb.BdbQuit):
xax/utils/experiments.py CHANGED
@@ -111,8 +111,10 @@ class StateTimer:
111
111
 
112
112
  def step(self, state: State) -> None:
113
113
  cur_time = time.time()
114
- self.step_timer.step(state.num_steps if state.phase == "train" else state.num_valid_steps, cur_time)
115
- self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
114
+ num_steps = int((state.num_steps if state.phase == "train" else state.num_valid_steps).item())
115
+ num_samples = int((state.num_samples if state.phase == "train" else state.num_valid_samples).item())
116
+ self.step_timer.step(num_steps, cur_time)
117
+ self.sample_timer.step(num_samples, cur_time)
116
118
  self.iter_timer.step(cur_time)
117
119
 
118
120
  def log_dict(self) -> dict[str, int | float | tuple[int | float, bool]]:
xax/utils/jax.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """Defines some utility functions for interfacing with Jax."""
2
2
 
3
+ import functools
3
4
  import inspect
4
5
  import logging
5
6
  import os
@@ -23,6 +24,28 @@ Number = int | float | np.ndarray | jnp.ndarray
23
24
  P = ParamSpec("P") # For function parameters
24
25
  R = TypeVar("R") # For function return type
25
26
 
27
+ # For control flow functions.
28
+ Carry = TypeVar("Carry")
29
+ X = TypeVar("X")
30
+ Y = TypeVar("Y")
31
+
32
+
33
+ @functools.lru_cache(maxsize=None)
34
+ def disable_jit_level() -> int:
35
+ """Gets a debugging flag for disabling jitting.
36
+
37
+ For Xax's JIT'ed functions, we can set a JIT level which can be used to
38
+ disable jitting when we want to debug some NaN issues.
39
+
40
+ Returns:
41
+ The JIT level to disable.
42
+ """
43
+ return int(os.environ.get("DISABLE_JIT_LEVEL", "0"))
44
+
45
+
46
+ def should_disable_jit(jit_level: int | None) -> bool:
47
+ return jit_level is not None and jit_level < disable_jit_level()
48
+
26
49
 
27
50
  def as_float(value: int | float | np.ndarray | jnp.ndarray) -> float:
28
51
  if isinstance(value, (int, float)):
@@ -55,6 +78,7 @@ def jit(
55
78
  inline: bool = False,
56
79
  abstracted_axes: Any | None = None, # noqa: ANN401
57
80
  compiler_options: dict[str, Any] | None = None,
81
+ jit_level: int | None = None,
58
82
  ) -> Callable[[Callable[P, R]], Callable[P, R]]:
59
83
  """Wrapper function that provides utility improvements over Jax's JIT.
60
84
 
@@ -64,6 +88,8 @@ def jit(
64
88
  This is meant to be used as a decorator factory, and the decorated function
65
89
  calls `wrapped`.
66
90
  """
91
+ if should_disable_jit(jit_level):
92
+ return lambda fn: fn # Identity function.
67
93
 
68
94
  def decorator(fn: Callable[P, R]) -> Callable[P, R]:
69
95
  class JitState:
@@ -138,3 +164,46 @@ def jit(
138
164
  return wrapped
139
165
 
140
166
  return decorator
167
+
168
+
169
+ def scan(
170
+ f: Callable[[Carry, X], tuple[Carry, Y]],
171
+ init: Carry,
172
+ xs: X | None = None,
173
+ length: int | None = None,
174
+ reverse: bool = False,
175
+ unroll: int | bool = 1,
176
+ jit_level: int | None = None,
177
+ ) -> tuple[Carry, Y]:
178
+ """A wrapper around jax.lax.scan that allows for more flexible tracing.
179
+
180
+ If the provided JIT level is below the environment JIT level, we manually
181
+ unroll the scan function as a for loop.
182
+
183
+ Args:
184
+ f: The function to scan.
185
+ init: The initial value for the scan.
186
+ xs: The input to the scan.
187
+ length: The length of the scan.
188
+ reverse: Whether to reverse the scan.
189
+ unroll: The unroll factor for the scan.
190
+ jit_level: The JIT level to use for the scan.
191
+
192
+ Returns:
193
+ A tuple containing the final carry and the output of the scan.
194
+ """
195
+ if not should_disable_jit(jit_level):
196
+ return jax.lax.scan(f, init, xs, length, reverse, unroll)
197
+
198
+ if xs is None:
199
+ if length is None:
200
+ raise ValueError("length must be provided if xs is None")
201
+ xs = cast(X, [None] * length)
202
+
203
+ carry = init
204
+ ys = []
205
+ for x in cast(Iterable, xs):
206
+ carry, y = f(carry, x)
207
+ ys.append(y)
208
+
209
+ return carry, jax.tree.map(lambda *ys: jnp.stack(ys), *ys)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,10 +1,10 @@
1
- xax/__init__.py,sha256=Yj2SgoKyIAQzg3bt-hAS4gf0fqlfVBR4pv4JgpTl7-s,14182
1
+ xax/__init__.py,sha256=P4q2IGkfpHaN3ZlGFiW0bzWm1spLSUyl0GEPvH8oITg,14225
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
5
5
  xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
7
- xax/core/state.py,sha256=XejW1tGINYFFcNrscK8eZQsq02J7_RXa461QpmyWuLk,3337
7
+ xax/core/state.py,sha256=lA7A5HCm2Nwk4J0kJGlTIhqHYFvbuwHfdNzOhmjEW08,4453
8
8
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
10
  xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
@@ -17,7 +17,7 @@ xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
17
17
  xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
18
18
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
19
  xax/task/base.py,sha256=OnXi2hiKPGwt6ng1dutnoQSiw7lEiWFlC_vx99_JsbQ,7694
20
- xax/task/logger.py,sha256=peGtfnvnBKr9l6tx1V6XAsvPs0HP6ubV_aE7IJtOMNk,40868
20
+ xax/task/logger.py,sha256=y4PGfMqKbfvPk8WCzr9MOsgG2X9E61KgeBVOYp-9kOY,40875
21
21
  xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
22
22
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
23
23
  xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -28,11 +28,11 @@ xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
28
28
  xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
29
29
  xax/task/loggers/json.py,sha256=_tKum6jk_gqVzO-4MqSNXbE-Mmn-yJzkRAT-N1y2zes,4139
30
30
  xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
31
- xax/task/loggers/stdout.py,sha256=oeIgPkj4RyJgBuWaJK9ncLa65iBNJCWXhSF8fx3_54c,6564
32
- xax/task/loggers/tensorboard.py,sha256=KOL9l60tLctX-VAdNwe49H48SAJeGxph3sflJpojA-4,8337
31
+ xax/task/loggers/stdout.py,sha256=ERLFrYe61hSSztzyxBRseobHQR72YFYjEd2i_hOeJ20,6595
32
+ xax/task/loggers/tensorboard.py,sha256=3ohI6STgSCbU8oyeiH_f3QyLVF_zO_6dwjn0ns59rUU,8334
33
33
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
34
34
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
35
- xax/task/mixins/checkpointing.py,sha256=2nJgqFcV-D8W-4j8TR3PvVh1g5hQUOo-_quKO-XlE4U,11398
35
+ xax/task/mixins/checkpointing.py,sha256=8Hi-2G0EA5OFRjgiOutlk7HgkD5b-0GHazOAYxnGytM,11409
36
36
  xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
37
37
  xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
38
38
  xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
@@ -41,11 +41,11 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
41
41
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
42
42
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
43
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
44
- xax/task/mixins/train.py,sha256=v9oi9tNsNBYo-Ne_98nCG9qHX6sxvymHjsRDnL6GL-U,30871
44
+ xax/task/mixins/train.py,sha256=lMHCnxsbZJbwK3esL5S3cJ0Jf5Qx19Y4pm3A7NY-TIE,31064
45
45
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
46
  xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
47
- xax/utils/experiments.py,sha256=Hzl46_9IH5_9cKzxit-FyVUWBH-_lBs00ZciuIdnWO8,29811
48
- xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
47
+ xax/utils/experiments.py,sha256=d2H63ECtVOKySMUMrQRqq4kcuZpoXqo-L931usDVAhE,29903
48
+ xax/utils/jax.py,sha256=KQYUHjN6t6JIWa11aRSO3edcsAgTscw_dExxI6kCd9g,6767
49
49
  xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
50
50
  xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
51
51
  xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.2.2.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.2.2.dist-info/METADATA,sha256=Ku0h6R6WToJ4rMYhcswGLXtIGVtzouWIGelHZFW30IM,1882
63
- xax-0.2.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
- xax-0.2.2.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.2.2.dist-info/RECORD,,
61
+ xax-0.2.3.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.2.3.dist-info/METADATA,sha256=ukAnG444wnzRpXgmHSrs7RKJ-UQvOdl6ZE2ZrN0w4Yg,1882
63
+ xax-0.2.3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
+ xax-0.2.3.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.2.3.dist-info/RECORD,,
File without changes