xax-0.3.8-py3-none-any.whl → xax-0.3.10-py3-none-any.whl
- xax/__init__.py +1 -1
- xax/task/mixins/data_loader.py +7 -2
- xax/task/mixins/train.py +45 -57
- xax/utils/pytree.py +11 -4
- {xax-0.3.8.dist-info → xax-0.3.10.dist-info}/METADATA +1 -1
- {xax-0.3.8.dist-info → xax-0.3.10.dist-info}/RECORD +10 -10
- {xax-0.3.8.dist-info → xax-0.3.10.dist-info}/WHEEL +0 -0
- {xax-0.3.8.dist-info → xax-0.3.10.dist-info}/entry_points.txt +0 -0
- {xax-0.3.8.dist-info → xax-0.3.10.dist-info}/licenses/LICENSE +0 -0
- {xax-0.3.8.dist-info → xax-0.3.10.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
xax/task/mixins/data_loader.py
CHANGED
@@ -110,7 +110,12 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
                 "or `get_data_iterator` to return an iterator for the given dataset."
             )

-    def get_dataloader(
+    def get_dataloader(
+        self,
+        dataset: Dataset[T, Tc_co],
+        phase: Phase,
+        prefetch_factor: int | None = None,
+    ) -> Dataloader[T, Tc_co]:
         debugging = self.config.debug_dataloader
         if debugging:
             logger.warning("Parallel dataloaders disabled in debugging mode")
@@ -135,7 +140,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
             dataset=dataset,
             batch_size=self.config.batch_size,
             num_workers=0 if debugging else cfg.num_workers,
-            prefetch_factor=cfg.prefetch_factor,
+            prefetch_factor=cfg.prefetch_factor if prefetch_factor is None else prefetch_factor,
             mp_manager=self.multiprocessing_manager,
             dataloader_worker_init_fn=self.dataloader_worker_init_fn,
             collate_worker_init_fn=self.collate_worker_init_fn,
xax/task/mixins/train.py
CHANGED
@@ -60,7 +60,7 @@ from xax.utils.experiments import (
     get_state_file_string,
     get_training_code,
 )
-from xax.utils.jax import jit as xax_jit
+from xax.utils.jax import jit as xax_jit, scan as xax_scan
 from xax.utils.logging import LOG_PING, LOG_STATUS
 from xax.utils.pytree import get_pytree_param_count
 from xax.utils.text import highlight_exception_message, show_info
@@ -175,6 +175,7 @@ class TrainConfig(
     valid_first_n_seconds: float | None = field(60.0, help="Run first validation after N seconds")
     max_steps: int | None = field(None, help="Maximum number of steps to run")
     step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
+    updates_per_step: int = field(1, help="Number of updates to perform per step")
     random_seed: int = field(1337, help="Random seed for the task")
     global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")

@@ -597,6 +598,7 @@ class TrainMixin(
         metrics = self.compute_metrics(model, batch, output, loss, state)
         return loss, (output, metrics)

+    @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
     def update(
         self,
         model_arr: PyTree,
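For readers unfamiliar with the `model_arr` / `model_static` split that `update` takes: this is the standard Equinox partition/combine pattern, where array leaves are traced through jit/grad and everything else is treated as static. A simplified sketch with plain `jax.jit` standing in for the `xax_jit` decorator added above:

import equinox as eqx
import jax
import jax.numpy as jnp

# Split a model into its array leaves (traced) and its static structure.
model = eqx.nn.Linear(3, 2, key=jax.random.PRNGKey(0))
model_arr, model_static = eqx.partition(model, eqx.is_array)

@jax.jit
def forward(model_arr, x):
    # Recombine inside the jitted function; the static part is a closure constant.
    model = eqx.combine(model_arr, model_static)
    return model(x)

print(forward(model_arr, jnp.ones(3)))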
@@ -609,44 +611,9 @@ class TrainMixin(
         grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
         grad_fn = xax_jit(static_argnums=[1], jit_level=3)(grad_fn)
         grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
-
-
-
-    @xax_jit(static_argnames=["self", "optimizer"], jit_level=3)
-    def apply_gradients_with_clipping(
-        self,
-        model_arr: PyTree,
-        grads: PyTree,
-        optimizer: optax.GradientTransformation,
-        opt_state: optax.OptState,
-    ) -> tuple[PyTree, optax.OptState, dict[str, Array]]:
-        grad_norm = optax.global_norm(grads)
-        grad_metrics = {"grad_norm": grad_norm}
-
-        def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
-            # Clip gradients based on global norm, similar to optax.clip_by_global_norm
-            trigger = jnp.squeeze(grad_norm < self.config.global_grad_clip)
-
-            def clip_fn(t: Array) -> Array:
-                return jax.lax.select(trigger, t, (t / grad_norm.astype(t.dtype)) * self.config.global_grad_clip)
-
-            grads = jax.tree.map(clip_fn, grads)
-
-            # Apply the gradient updates.
-            updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
-            new_model_arr = eqx.apply_updates(model_arr, updates)
-            return new_model_arr, new_opt_state
-
-        # Don't apply updates if the gradient is NaN or Inf.
-        new_model_arr, new_opt_state = jax.lax.cond(
-            jnp.isnan(grad_norm) | jnp.isinf(grad_norm),
-            lambda *_: (model_arr, opt_state),
-            apply,
-            grads,
-            grad_norm,
-        )
-
-        return new_model_arr, new_opt_state, grad_metrics
+        updates, opt_state = optimizer.update(grads, opt_state, model_arr)
+        model_arr = eqx.apply_updates(model_arr, updates)
+        return model_arr, opt_state, output, metrics

     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
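Note that the removed `apply_gradients_with_clipping` helper implemented global-norm clipping and NaN/Inf protection by hand, whereas the new `update` body applies the optimizer directly. If equivalent clipping is still wanted, one common approach (a sketch only, not necessarily what xax 0.3.10 does) is to chain `optax.clip_by_global_norm` into the optimizer itself, so the plain `optimizer.update()` call still sees clipped gradients:

import optax

max_grad_norm = 10.0  # would correspond to config.global_grad_clip
# Clip by global norm, then apply the Adam update, as a single GradientTransformation.
optimizer = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.adam(learning_rate=3e-4),
)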
@@ -729,11 +696,36 @@ class TrainMixin(
         model_static: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
-
+        batches: Batch,
         state: State,
     ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
-
-
+        def update_fn(
+            carry: tuple[PyTree, optax.OptState],
+            batch: Batch,
+        ) -> tuple[tuple[PyTree, optax.OptState], tuple[Output, FrozenDict[str, Array]]]:
+            model_arr, opt_state = carry
+            model_arr, opt_state, output, metrics = self.update(
+                model_arr,
+                model_static,
+                optimizer,
+                opt_state,
+                batch,
+                state,
+            )
+            return (model_arr, opt_state), (output, FrozenDict(metrics))
+
+        (model_arr, opt_state), (output, metrics) = xax_scan(
+            update_fn,
+            (model_arr, opt_state),
+            batches,
+            jit_level=3,
+        )
+
+        # Only get the final output and metrics.
+        output = jax.tree.map(lambda x: x[-1], output)
+        metrics = jax.tree.map(lambda x: x[-1], metrics)
+
+        return model_arr, opt_state, output, metrics

     @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
     def val_step(
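The new `train_step` threads `(model_arr, opt_state)` through `updates_per_step` minibatches with a scan. A self-contained toy version of the same pattern using plain `jax.lax.scan` (xax's `xax_scan` wrapper additionally takes a `jit_level`), with a linear-regression update standing in for `self.update`:

import jax
import jax.numpy as jnp
import optax

optimizer = optax.sgd(1e-2)
params = jnp.zeros(3)
opt_state = optimizer.init(params)

def loss_fn(params, batch):
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)

def update_fn(carry, batch):
    # Carry the trainable state across minibatches; emit a per-update loss.
    params, opt_state = carry
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return (params, opt_state), loss_fn(params, batch)

# `batches` has a leading axis of length updates_per_step, as produced by stacking.
key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (4, 32, 3))   # 4 updates per step, batch size 32
ys = xs @ jnp.array([1.0, -2.0, 0.5])
(params, opt_state), losses = jax.lax.scan(update_fn, (params, opt_state), (xs, ys))
print(losses)  # one loss per inner update; xax keeps only the last output/metrics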
@@ -775,40 +767,36 @@ class TrainMixin(
                     output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
                     self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)

-                state = state.replace(
-                    num_steps=state.num_steps + 1,
-                    num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
-                )
-
                 state = state.replace(
+                    num_steps=state.num_steps + 1,
+                    num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
                     elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
                 )

                 with ContextTimer() as timer:
                     state = self.on_step_start(state)
                     state = state.replace(phase="train")
-
+                    train_batches = list(itertools.islice(train_pf, self.config.updates_per_step))
                     model_arr, opt_state, output, metrics = self.train_step(
                         model_arr=model_arr,
                         model_static=model_static,
                         optimizer=optimizer,
                         opt_state=opt_state,
-
+                        batches=jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches),
                         state=state,
                     )
-                    self.log_step(eqx.combine(model_arr, model_static),
-
-                    state = state.replace(
-                        num_steps=state.num_steps + 1,
-                        num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
-                    )
-
+                    self.log_step(eqx.combine(model_arr, model_static), train_batches[-1], output, metrics, state)
                     state = self.on_step_end(state)

                 state = state.replace(
+                    num_steps=state.num_steps + 1,
+                    num_samples=state.num_samples + (self.get_size_of_batch(train_batches[-1]) or 0),
                     elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
                 )

+                if state.num_steps <= 3:
+                    logger.log(LOG_PING, "Step %d took %.2f second", state.num_steps, timer.elapsed_time)
+
                 if self.should_checkpoint(state):
                     model = eqx.combine(model_arr, model_static)
                     self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
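The training loop now pulls `updates_per_step` batches from the prefetcher and stacks them leaf-wise before calling `train_step`. A small sketch of exactly that stacking expression on a toy dict-of-arrays batch:

import jax
import jax.numpy as jnp

# Three "batches", each a dict pytree with the same structure and leaf shapes.
train_batches = [
    {"x": jnp.ones((32, 3)) * i, "y": jnp.zeros((32,))} for i in range(3)
]

# Same expression as in the diff: stack leaf-wise so every leaf gains a
# leading axis of length updates_per_step, ready to be scanned over.
stacked = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches)
print(stacked["x"].shape)  # (3, 32, 3)
print(stacked["y"].shape)  # (3, 32)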
@@ -827,7 +815,7 @@ class TrainMixin(
             pass

         train_ds = self.get_dataset("train")
-        train_dl = self.get_dataloader(train_ds, "train")
+        train_dl = self.get_dataloader(train_ds, "train", prefetch_factor=self.config.updates_per_step + 1)
         train_pf = self.get_prefetcher(train_dl)

         try:
xax/utils/pytree.py
CHANGED
@@ -1,6 +1,6 @@
 """Utils for accessing, modifying, and otherwise manipulating pytrees."""

-from typing import TypeVar
+from typing import Mapping, Sequence, TypeVar

 import chex
 import equinox as eqx
@@ -258,11 +258,18 @@ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
 def get_pytree_mapping(pytree: PyTree) -> dict[str, Array]:
     leaves: dict[str, Array] = {}

+    def _get_str(thing: PyTree) -> str:
+        if isinstance(thing, str):
+            return thing
+        if isinstance(thing, Sequence):
+            return "/".join(_get_str(x) for x in thing)
+        if isinstance(thing, Mapping):
+            return "/".join(f"{_get_str(k)}:{_get_str(v)}" for k, v in thing.items())
+        return str(thing)
+
     def _get_leaf(path: tuple, x: PyTree) -> None:
         if isinstance(x, jnp.ndarray):
-
-            path_str = "/".join(str(p) for p in path)
-            leaves[path_str] = x
+            leaves[_get_str(path)] = x

     jax.tree.map_with_path(_get_leaf, pytree)
     return leaves
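The change above makes `get_pytree_mapping` render the key paths produced by `jax.tree.map_with_path` as readable strings. A small standalone illustration of the same idea, using `jax.tree_util.keystr` instead of xax's `_get_str` helper (so the string format here is keystr's, not xax's):

import jax
import jax.numpy as jnp

params = {"encoder": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}, "scale": jnp.array(1.0)}

leaves: dict[str, jnp.ndarray] = {}

def _get_leaf(path, x):
    if isinstance(x, jnp.ndarray):
        # keystr renders a path tuple such as (DictKey('encoder'), DictKey('w'))
        # as a single readable string; xax builds its own "/"-joined variant.
        leaves[jax.tree_util.keystr(path)] = x

jax.tree_util.tree_map_with_path(_get_leaf, params)
print(sorted(leaves))  # e.g. ["['encoder']['b']", "['encoder']['w']", "['scale']"]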
{xax-0.3.8.dist-info → xax-0.3.10.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=lSwyrPTof_BZ-pyPNhNICJnCZMN9i2sJ-Ii3S_vY_28,16666
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -37,13 +37,13 @@ xax/task/mixins/artifacts.py,sha256=R-y3p7__zJHlHDqwDVAZysg2ZmebCJbqAx_xGT2Xpd0,
 xax/task/mixins/checkpointing.py,sha256=v50IZ7j58DWmEu-_6Zh_02R5KUVGhrMkg5n-MYM_J4c,11484
 xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
 xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
-xax/task/mixins/data_loader.py,sha256=
+xax/task/mixins/data_loader.py,sha256=BKfOVWXR70vbyHMFlnlUiQQHXHH5zTj5WtmsymNCFB4,6722
 xax/task/mixins/gpu_stats.py,sha256=USOyhXldxbsrl6eCtoFKTWUm_lfeG0cUCkQNUpXRdtA,8880
 xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=_kDpifLi1arSuT0ssFhBV0axpvLlQG3a97pohya0Eqc,32908
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
 xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
@@ -52,7 +52,7 @@ xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
 xax/utils/logging.py,sha256=Kkyma_LJXqrN2HTQ214gRP_9ih3_bKk115MWC60lQWM,6656
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
 xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
-xax/utils/pytree.py,sha256=
+xax/utils/pytree.py,sha256=w8Ab2LmJdQ8e1FxKF0xWWaOak09Mhu44ZcOeUR6uGFA,9889
 xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
 xax/utils/text.py,sha256=xS02aSzdywl3KIaNSpKWcxdd37oYlUJtu9wIjkc1wVc,10654
 xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -60,9 +60,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
+xax-0.3.10.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.3.10.dist-info/METADATA,sha256=oQMGYjsfYxMmw0A60qE15yda_G-0YG5RNl17tboR1f0,1247
+xax-0.3.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+xax-0.3.10.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+xax-0.3.10.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.3.10.dist-info/RECORD,,
{xax-0.3.8.dist-info → xax-0.3.10.dist-info}/WHEEL
File without changes
{xax-0.3.8.dist-info → xax-0.3.10.dist-info}/entry_points.txt
File without changes
{xax-0.3.8.dist-info → xax-0.3.10.dist-info}/licenses/LICENSE
File without changes
{xax-0.3.8.dist-info → xax-0.3.10.dist-info}/top_level.txt
File without changes