xax 0.3.8__tar.gz → 0.3.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. {xax-0.3.8/xax.egg-info → xax-0.3.10}/PKG-INFO +1 -1
  2. {xax-0.3.8 → xax-0.3.10}/xax/__init__.py +1 -1
  3. {xax-0.3.8 → xax-0.3.10}/xax/task/mixins/data_loader.py +7 -2
  4. {xax-0.3.8 → xax-0.3.10}/xax/task/mixins/train.py +45 -57
  5. {xax-0.3.8 → xax-0.3.10}/xax/utils/pytree.py +11 -4
  6. {xax-0.3.8 → xax-0.3.10/xax.egg-info}/PKG-INFO +1 -1
  7. {xax-0.3.8 → xax-0.3.10}/LICENSE +0 -0
  8. {xax-0.3.8 → xax-0.3.10}/MANIFEST.in +0 -0
  9. {xax-0.3.8 → xax-0.3.10}/README.md +0 -0
  10. {xax-0.3.8 → xax-0.3.10}/pyproject.toml +0 -0
  11. {xax-0.3.8 → xax-0.3.10}/setup.cfg +0 -0
  12. {xax-0.3.8 → xax-0.3.10}/setup.py +0 -0
  13. {xax-0.3.8 → xax-0.3.10}/xax/cli/__init__.py +0 -0
  14. {xax-0.3.8 → xax-0.3.10}/xax/cli/edit_config.py +0 -0
  15. {xax-0.3.8 → xax-0.3.10}/xax/core/__init__.py +0 -0
  16. {xax-0.3.8 → xax-0.3.10}/xax/core/conf.py +0 -0
  17. {xax-0.3.8 → xax-0.3.10}/xax/core/state.py +0 -0
  18. {xax-0.3.8 → xax-0.3.10}/xax/nn/__init__.py +0 -0
  19. {xax-0.3.8 → xax-0.3.10}/xax/nn/attention.py +0 -0
  20. {xax-0.3.8 → xax-0.3.10}/xax/nn/distributions.py +0 -0
  21. {xax-0.3.8 → xax-0.3.10}/xax/nn/embeddings.py +0 -0
  22. {xax-0.3.8 → xax-0.3.10}/xax/nn/functions.py +0 -0
  23. {xax-0.3.8 → xax-0.3.10}/xax/nn/geom.py +0 -0
  24. {xax-0.3.8 → xax-0.3.10}/xax/nn/losses.py +0 -0
  25. {xax-0.3.8 → xax-0.3.10}/xax/nn/metrics.py +0 -0
  26. {xax-0.3.8 → xax-0.3.10}/xax/nn/parallel.py +0 -0
  27. {xax-0.3.8 → xax-0.3.10}/xax/nn/ssm.py +0 -0
  28. {xax-0.3.8 → xax-0.3.10}/xax/py.typed +0 -0
  29. {xax-0.3.8 → xax-0.3.10}/xax/requirements-dev.txt +0 -0
  30. {xax-0.3.8 → xax-0.3.10}/xax/requirements.txt +0 -0
  31. {xax-0.3.8 → xax-0.3.10}/xax/task/__init__.py +0 -0
  32. {xax-0.3.8 → xax-0.3.10}/xax/task/base.py +0 -0
  33. {xax-0.3.8 → xax-0.3.10}/xax/task/launchers/__init__.py +0 -0
  34. {xax-0.3.8 → xax-0.3.10}/xax/task/launchers/base.py +0 -0
  35. {xax-0.3.8 → xax-0.3.10}/xax/task/launchers/cli.py +0 -0
  36. {xax-0.3.8 → xax-0.3.10}/xax/task/launchers/single_process.py +0 -0
  37. {xax-0.3.8 → xax-0.3.10}/xax/task/logger.py +0 -0
  38. {xax-0.3.8 → xax-0.3.10}/xax/task/loggers/__init__.py +0 -0
  39. {xax-0.3.8 → xax-0.3.10}/xax/task/loggers/callback.py +0 -0
  40. {xax-0.3.8 → xax-0.3.10}/xax/task/loggers/json.py +0 -0
  41. {xax-0.3.8 → xax-0.3.10}/xax/task/loggers/state.py +0 -0
  42. {xax-0.3.8 → xax-0.3.10}/xax/task/loggers/stdout.py +0 -0
  43. {xax-0.3.8 → xax-0.3.10}/xax/task/loggers/tensorboard.py +0 -0
  44. {xax-0.3.8 → xax-0.3.10}/xax/task/mixins/__init__.py +0 -0
  45. {xax-0.3.8 → xax-0.3.10}/xax/task/mixins/artifacts.py +0 -0
  46. {xax-0.3.8 → xax-0.3.10}/xax/task/mixins/checkpointing.py +0 -0
  47. {xax-0.3.8 → xax-0.3.10}/xax/task/mixins/compile.py +0 -0
  48. {xax-0.3.8 → xax-0.3.10}/xax/task/mixins/cpu_stats.py +0 -0
  49. {xax-0.3.8 → xax-0.3.10}/xax/task/mixins/gpu_stats.py +0 -0
  50. {xax-0.3.8 → xax-0.3.10}/xax/task/mixins/logger.py +0 -0
  51. {xax-0.3.8 → xax-0.3.10}/xax/task/mixins/process.py +0 -0
  52. {xax-0.3.8 → xax-0.3.10}/xax/task/mixins/runnable.py +0 -0
  53. {xax-0.3.8 → xax-0.3.10}/xax/task/mixins/step_wrapper.py +0 -0
  54. {xax-0.3.8 → xax-0.3.10}/xax/task/script.py +0 -0
  55. {xax-0.3.8 → xax-0.3.10}/xax/task/task.py +0 -0
  56. {xax-0.3.8 → xax-0.3.10}/xax/utils/__init__.py +0 -0
  57. {xax-0.3.8 → xax-0.3.10}/xax/utils/data/__init__.py +0 -0
  58. {xax-0.3.8 → xax-0.3.10}/xax/utils/data/collate.py +0 -0
  59. {xax-0.3.8 → xax-0.3.10}/xax/utils/debugging.py +0 -0
  60. {xax-0.3.8 → xax-0.3.10}/xax/utils/experiments.py +0 -0
  61. {xax-0.3.8 → xax-0.3.10}/xax/utils/jax.py +0 -0
  62. {xax-0.3.8 → xax-0.3.10}/xax/utils/jaxpr.py +0 -0
  63. {xax-0.3.8 → xax-0.3.10}/xax/utils/logging.py +0 -0
  64. {xax-0.3.8 → xax-0.3.10}/xax/utils/numpy.py +0 -0
  65. {xax-0.3.8 → xax-0.3.10}/xax/utils/profile.py +0 -0
  66. {xax-0.3.8 → xax-0.3.10}/xax/utils/tensorboard.py +0 -0
  67. {xax-0.3.8 → xax-0.3.10}/xax/utils/text.py +0 -0
  68. {xax-0.3.8 → xax-0.3.10}/xax/utils/types/__init__.py +0 -0
  69. {xax-0.3.8 → xax-0.3.10}/xax/utils/types/frozen_dict.py +0 -0
  70. {xax-0.3.8 → xax-0.3.10}/xax/utils/types/hashable_array.py +0 -0
  71. {xax-0.3.8 → xax-0.3.10}/xax.egg-info/SOURCES.txt +0 -0
  72. {xax-0.3.8 → xax-0.3.10}/xax.egg-info/dependency_links.txt +0 -0
  73. {xax-0.3.8 → xax-0.3.10}/xax.egg-info/entry_points.txt +0 -0
  74. {xax-0.3.8 → xax-0.3.10}/xax.egg-info/requires.txt +0 -0
  75. {xax-0.3.8 → xax-0.3.10}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: xax
- Version: 0.3.8
+ Version: 0.3.10
  Summary: A library for fast Jax experimentation
  Home-page: https://github.com/kscalelabs/xax
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
      python -m scripts.update_api --inplace
  """
  
- __version__ = "0.3.8"
+ __version__ = "0.3.10"
  
  # This list shouldn't be modified by hand; instead, run the update script.
  __all__ = [
@@ -110,7 +110,12 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
              "or `get_data_iterator` to return an iterator for the given dataset."
          )
  
-     def get_dataloader(self, dataset: Dataset[T, Tc_co], phase: Phase) -> Dataloader[T, Tc_co]:
+     def get_dataloader(
+         self,
+         dataset: Dataset[T, Tc_co],
+         phase: Phase,
+         prefetch_factor: int | None = None,
+     ) -> Dataloader[T, Tc_co]:
          debugging = self.config.debug_dataloader
          if debugging:
              logger.warning("Parallel dataloaders disabled in debugging mode")
@@ -135,7 +140,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
              dataset=dataset,
              batch_size=self.config.batch_size,
              num_workers=0 if debugging else cfg.num_workers,
-             prefetch_factor=cfg.prefetch_factor,
+             prefetch_factor=cfg.prefetch_factor if prefetch_factor is None else prefetch_factor,
              mp_manager=self.multiprocessing_manager,
              dataloader_worker_init_fn=self.dataloader_worker_init_fn,
              collate_worker_init_fn=self.collate_worker_init_fn,
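The two hunks above add an optional prefetch_factor argument to get_dataloader: when it is None the config value is used, otherwise the explicit value wins. A minimal usage sketch, assuming a DataloadersMixin subclass instance named task and a dataset train_ds (both hypothetical; only the new keyword itself comes from the diff):

    # Falls back to cfg.prefetch_factor from the dataloader config.
    train_dl = task.get_dataloader(train_ds, "train")
    # Explicit override, e.g. to prefetch further ahead for multi-update steps.
    train_dl = task.get_dataloader(train_ds, "train", prefetch_factor=4)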
@@ -60,7 +60,7 @@ from xax.utils.experiments import (
      get_state_file_string,
      get_training_code,
  )
- from xax.utils.jax import jit as xax_jit
+ from xax.utils.jax import jit as xax_jit, scan as xax_scan
  from xax.utils.logging import LOG_PING, LOG_STATUS
  from xax.utils.pytree import get_pytree_param_count
  from xax.utils.text import highlight_exception_message, show_info
@@ -175,6 +175,7 @@ class TrainConfig(
      valid_first_n_seconds: float | None = field(60.0, help="Run first validation after N seconds")
      max_steps: int | None = field(None, help="Maximum number of steps to run")
      step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
+     updates_per_step: int = field(1, help="Number of updates to perform per step")
      random_seed: int = field(1337, help="Random seed for the task")
      global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")
  
@@ -597,6 +598,7 @@ class TrainMixin(
          metrics = self.compute_metrics(model, batch, output, loss, state)
          return loss, (output, metrics)
  
+     @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
      def update(
          self,
          model_arr: PyTree,
@@ -609,44 +611,9 @@ class TrainMixin(
          grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
          grad_fn = xax_jit(static_argnums=[1], jit_level=3)(grad_fn)
          grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
-         model_arr, opt_state, grad_metrics = self.apply_gradients_with_clipping(model_arr, grads, optimizer, opt_state)
-         return model_arr, opt_state, output, metrics | grad_metrics
- 
-     @xax_jit(static_argnames=["self", "optimizer"], jit_level=3)
-     def apply_gradients_with_clipping(
-         self,
-         model_arr: PyTree,
-         grads: PyTree,
-         optimizer: optax.GradientTransformation,
-         opt_state: optax.OptState,
-     ) -> tuple[PyTree, optax.OptState, dict[str, Array]]:
-         grad_norm = optax.global_norm(grads)
-         grad_metrics = {"grad_norm": grad_norm}
- 
-         def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
-             # Clip gradients based on global norm, similar to optax.clip_by_global_norm
-             trigger = jnp.squeeze(grad_norm < self.config.global_grad_clip)
- 
-             def clip_fn(t: Array) -> Array:
-                 return jax.lax.select(trigger, t, (t / grad_norm.astype(t.dtype)) * self.config.global_grad_clip)
- 
-             grads = jax.tree.map(clip_fn, grads)
- 
-             # Apply the gradient updates.
-             updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
-             new_model_arr = eqx.apply_updates(model_arr, updates)
-             return new_model_arr, new_opt_state
- 
-         # Don't apply updates if the gradient is NaN or Inf.
-         new_model_arr, new_opt_state = jax.lax.cond(
-             jnp.isnan(grad_norm) | jnp.isinf(grad_norm),
-             lambda *_: (model_arr, opt_state),
-             apply,
-             grads,
-             grad_norm,
-         )
- 
-         return new_model_arr, new_opt_state, grad_metrics
+         updates, opt_state = optimizer.update(grads, opt_state, model_arr)
+         model_arr = eqx.apply_updates(model_arr, updates)
+         return model_arr, opt_state, output, metrics
  
      def get_size_of_batch(self, batch: Batch) -> int | None:
          """Gets the batch size for the current batch.
@@ -729,11 +696,36 @@ class TrainMixin(
          model_static: PyTree,
          optimizer: optax.GradientTransformation,
          opt_state: optax.OptState,
-         batch: Batch,
+         batches: Batch,
          state: State,
      ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
-         model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
-         return model_arr, opt_state, output, FrozenDict(metrics)
+         def update_fn(
+             carry: tuple[PyTree, optax.OptState],
+             batch: Batch,
+         ) -> tuple[tuple[PyTree, optax.OptState], tuple[Output, FrozenDict[str, Array]]]:
+             model_arr, opt_state = carry
+             model_arr, opt_state, output, metrics = self.update(
+                 model_arr,
+                 model_static,
+                 optimizer,
+                 opt_state,
+                 batch,
+                 state,
+             )
+             return (model_arr, opt_state), (output, FrozenDict(metrics))
+ 
+         (model_arr, opt_state), (output, metrics) = xax_scan(
+             update_fn,
+             (model_arr, opt_state),
+             batches,
+             jit_level=3,
+         )
+ 
+         # Only get the final output and metrics.
+         output = jax.tree.map(lambda x: x[-1], output)
+         metrics = jax.tree.map(lambda x: x[-1], metrics)
+ 
+         return model_arr, opt_state, output, metrics
  
      @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
      def val_step(
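train_step now scans update() over a leading axis of the stacked batches, so config.updates_per_step optimizer updates run inside one compiled call, carrying (model_arr, opt_state) between them and keeping only the last output and metrics. The same carry-and-scan shape, shown with plain jax.lax.scan on a toy SGD update (xax_scan is the library's jit-level-aware wrapper; everything else below is an illustrative stand-in):

    import jax
    import jax.numpy as jnp

    def sgd_update(carry, batch):
        # carry is the parameter vector; batch is one (x, y) slice from the leading axis.
        params = carry
        x, y = batch

        def loss_fn(p):
            return jnp.mean((x @ p - y) ** 2)

        grads = jax.grad(loss_fn)(params)
        params = params - 0.1 * grads
        return params, loss_fn(params)  # new carry, per-update metric

    params = jnp.zeros(3)
    xs, ys = jnp.ones((4, 8, 3)), jnp.ones((4, 8))   # leading axis = updates per step
    params, losses = jax.lax.scan(sgd_update, params, (xs, ys))
    print(losses.shape)  # (4,): one loss per inner update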
@@ -775,40 +767,36 @@ class TrainMixin(
                      output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
                      self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
  
-                     state = state.replace(
-                         num_steps=state.num_steps + 1,
-                         num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
-                     )
- 
                  state = state.replace(
+                     num_steps=state.num_steps + 1,
+                     num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
                      elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
                  )
  
              with ContextTimer() as timer:
                  state = self.on_step_start(state)
                  state = state.replace(phase="train")
-                 train_batch = next(train_pf)
+                 train_batches = list(itertools.islice(train_pf, self.config.updates_per_step))
                  model_arr, opt_state, output, metrics = self.train_step(
                      model_arr=model_arr,
                      model_static=model_static,
                      optimizer=optimizer,
                      opt_state=opt_state,
-                     batch=train_batch,
+                     batches=jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches),
                      state=state,
                  )
-                 self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
- 
-                 state = state.replace(
-                     num_steps=state.num_steps + 1,
-                     num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
-                 )
- 
+                 self.log_step(eqx.combine(model_arr, model_static), train_batches[-1], output, metrics, state)
                  state = self.on_step_end(state)
  
                  state = state.replace(
+                     num_steps=state.num_steps + 1,
+                     num_samples=state.num_samples + (self.get_size_of_batch(train_batches[-1]) or 0),
                      elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
                  )
  
+                 if state.num_steps <= 3:
+                     logger.log(LOG_PING, "Step %d took %.2f second", state.num_steps, timer.elapsed_time)
+ 
              if self.should_checkpoint(state):
                  model = eqx.combine(model_arr, model_static)
                  self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
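The batches=... argument above is built by stacking the updates_per_step batches pulled from the prefetcher into a single pytree with a new leading axis, which is the axis the scan in train_step consumes. The stacking pattern in isolation, with a made-up batch structure:

    import jax
    import jax.numpy as jnp

    train_batches = [
        {"x": jnp.ones((8, 3)), "y": jnp.zeros((8,))},
        {"x": jnp.ones((8, 3)), "y": jnp.zeros((8,))},
    ]
    stacked = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches)
    print(stacked["x"].shape)  # (2, 8, 3): leading axis is updates_per_step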
@@ -827,7 +815,7 @@ class TrainMixin(
              pass
  
          train_ds = self.get_dataset("train")
-         train_dl = self.get_dataloader(train_ds, "train")
+         train_dl = self.get_dataloader(train_ds, "train", prefetch_factor=self.config.updates_per_step + 1)
          train_pf = self.get_prefetcher(train_dl)
  
          try:
@@ -1,6 +1,6 @@
  """Utils for accessing, modifying, and otherwise manipulating pytrees."""
  
- from typing import TypeVar
+ from typing import Mapping, Sequence, TypeVar
  
  import chex
  import equinox as eqx
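The new Mapping and Sequence imports support the _get_str helper added to get_pytree_mapping in the next hunk, which renders jax.tree path entries, which are not always plain strings, as slash-separated keys. A standalone sketch of the same flatten-to-string-keys idea using only public jax.tree_util APIs; the toy tree and the exact key formatting here are illustrative, not the library's output:

    import jax
    import jax.numpy as jnp

    tree = {"encoder": {"w": jnp.zeros((4, 4)), "b": jnp.zeros(4)}, "scale": jnp.ones(())}
    flat, _ = jax.tree_util.tree_flatten_with_path(tree)
    mapping = {jax.tree_util.keystr(path): leaf for path, leaf in flat}
    print(sorted(mapping))  # e.g. ["['encoder']['b']", "['encoder']['w']", "['scale']"]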
@@ -258,11 +258,18 @@ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
  def get_pytree_mapping(pytree: PyTree) -> dict[str, Array]:
      leaves: dict[str, Array] = {}
  
+     def _get_str(thing: PyTree) -> str:
+         if isinstance(thing, str):
+             return thing
+         if isinstance(thing, Sequence):
+             return "/".join(_get_str(x) for x in thing)
+         if isinstance(thing, Mapping):
+             return "/".join(f"{_get_str(k)}:{_get_str(v)}" for k, v in thing.items())
+         return str(thing)
+ 
      def _get_leaf(path: tuple, x: PyTree) -> None:
          if isinstance(x, jnp.ndarray):
-             # Convert path tuple to string, e.g. (1, 'a', 2) -> '1/a/2'
-             path_str = "/".join(str(p) for p in path)
-             leaves[path_str] = x
+             leaves[_get_str(path)] = x
  
      jax.tree.map_with_path(_get_leaf, pytree)
      return leaves
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: xax
- Version: 0.3.8
+ Version: 0.3.10
  Summary: A library for fast Jax experimentation
  Home-page: https://github.com/kscalelabs/xax
  Author: Benjamin Bolte