xax 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
      python -m scripts.update_api --inplace
  """

- __version__ = "0.1.8"
+ __version__ = "0.1.10"

  # This list shouldn't be modified by hand; instead, run the update script.
  __all__ = [
@@ -23,7 +23,6 @@ __all__ = [
      "get_run_dir",
      "load_user_config",
      "State",
-     "cast_phase",
      "FourierEmbeddings",
      "IdentityPositionalEmbeddings",
      "LearnedPositionalEmbeddings",
@@ -41,9 +40,6 @@ __all__ = [
      "load_eqx_mlp",
      "make_eqx_mlp",
      "save_eqx",
-     "export",
-     "export_flax",
-     "export_with_params",
      "euler_to_quat",
      "get_projected_gravity_vector_from_quat",
      "quat_to_euler",
@@ -180,7 +176,6 @@ NAME_MAP: dict[str, str] = {
      "get_run_dir": "core.conf",
      "load_user_config": "core.conf",
      "State": "core.state",
-     "cast_phase": "core.state",
      "FourierEmbeddings": "nn.embeddings",
      "IdentityPositionalEmbeddings": "nn.embeddings",
      "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -198,9 +193,6 @@ NAME_MAP: dict[str, str] = {
      "load_eqx_mlp": "nn.equinox",
      "make_eqx_mlp": "nn.equinox",
      "save_eqx": "nn.equinox",
-     "export": "nn.export",
-     "export_flax": "nn.export",
-     "export_with_params": "nn.export",
      "euler_to_quat": "nn.geom",
      "get_projected_gravity_vector_from_quat": "nn.geom",
      "quat_to_euler": "nn.geom",
@@ -329,7 +321,7 @@ if IMPORT_ALL or TYPE_CHECKING:
          get_run_dir,
          load_user_config,
      )
-     from xax.core.state import Phase, State, cast_phase
+     from xax.core.state import Phase, State
      from xax.nn.embeddings import (
          EmbeddingKind,
          FourierEmbeddings,
@@ -354,11 +346,6 @@ if IMPORT_ALL or TYPE_CHECKING:
          make_eqx_mlp,
          save_eqx,
      )
-     from xax.nn.export import (
-         export,
-         export_flax,
-         export_with_params,
-     )
      from xax.nn.geom import (
          euler_to_quat,
          get_projected_gravity_vector_from_quat,
xax/core/state.py CHANGED
@@ -1,9 +1,10 @@
  """Defines a dataclass for keeping track of the current training state."""

  import time
- from dataclasses import dataclass
- from typing import Literal, NotRequired, TypedDict, cast, get_args
+ from dataclasses import asdict, dataclass
+ from typing import Any, Literal, NotRequired, TypedDict, Unpack, cast

+ import jax
  from omegaconf import MISSING

  from xax.core.conf import field
@@ -11,12 +12,6 @@ from xax.core.conf import field
  Phase = Literal["train", "valid"]


- def cast_phase(raw_phase: str) -> Phase:
-     args = get_args(Phase)
-     assert raw_phase in args, f"Invalid phase: '{raw_phase}' Valid options are {args}"
-     return cast(Phase, raw_phase)
-
-
  class StateDict(TypedDict, total=False):
      num_steps: NotRequired[int]
      num_samples: NotRequired[int]
@@ -24,10 +19,11 @@ class StateDict(TypedDict, total=False):
      num_valid_samples: NotRequired[int]
      start_time_s: NotRequired[float]
      elapsed_time_s: NotRequired[float]
-     raw_phase: NotRequired[str]
+     phase: NotRequired[Phase]


- @dataclass
+ @jax.tree_util.register_dataclass
+ @dataclass(frozen=True, kw_only=True)
  class State:
      num_steps: int = field(MISSING, help="Number of steps so far")
      num_samples: int = field(MISSING, help="Number of sample so far")
@@ -35,15 +31,11 @@ class State:
      num_valid_samples: int = field(MISSING, help="Number of validation samples so far")
      start_time_s: float = field(MISSING, help="Start time of training")
      elapsed_time_s: float = field(MISSING, help="Total elapsed time so far")
-     raw_phase: str = field(MISSING, help="Current training phase")
+     _phase: int = field(MISSING, help="Current training phase")

      @property
      def phase(self) -> Phase:
-         return cast_phase(self.raw_phase)
-
-     @phase.setter
-     def phase(self, phase: Phase) -> None:
-         self.raw_phase = phase
+         return cast(Phase, ["train", "valid"][self._phase])

      @classmethod
      def init_state(cls) -> "State":
@@ -54,7 +46,7 @@ class State:
              num_valid_samples=0,
              start_time_s=time.time(),
              elapsed_time_s=0.0,
-             raw_phase="train",
+             _phase=0,
          )

      @property
@@ -69,3 +61,16 @@ class State:
                  return self.num_valid_steps
              case _:
                  raise ValueError(f"Invalid phase: {phase}")
+
+     def replace(self, **kwargs: Unpack[StateDict]) -> "State":
+         extra_kwargs: dict[str, Any] = {}  # noqa: ANN401
+         if "phase" in kwargs:
+             phase = kwargs.pop("phase")
+             match phase:
+                 case "train":
+                     extra_kwargs["_phase"] = 0
+                 case "valid":
+                     extra_kwargs["_phase"] = 1
+                 case _:
+                     raise ValueError(f"Invalid phase: {phase}")
+         return State(**{**asdict(self), **kwargs, **extra_kwargs})
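In 0.1.10, State is a frozen dataclass registered as a JAX pytree, so per-step updates go through the new replace(...) method instead of attribute assignment, and the phase is stored internally as an integer exposed through the phase property. A minimal usage sketch, not part of the diff itself, that only exercises the State API shown above:

# Minimal sketch (assumes xax 0.1.10 is installed); illustrates the new
# immutable-update workflow for the training state.
import jax
from xax.core.state import State

state = State.init_state()  # counters start at zero, phase starts as "train"
state = state.replace(phase="valid", num_valid_steps=state.num_valid_steps + 1)
assert state.phase == "valid"  # `_phase` is an int; the property maps it back to the literal

# Because State is registered via jax.tree_util.register_dataclass, it behaves
# like any other pytree (e.g. it can be flattened or passed through jit).
leaves = jax.tree_util.tree_leaves(state)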
xax/task/base.py CHANGED
@@ -16,7 +16,8 @@ from types import TracebackType
  from typing import Generic, Self, TypeVar, cast

  import jax
- from omegaconf import Container, DictConfig, OmegaConf
+ from omegaconf import DictConfig, OmegaConf
+ from omegaconf.base import SCMode

  from xax.core.state import State
  from xax.utils.text import camelcase_to_snakecase
@@ -66,9 +67,6 @@ class BaseTask(Generic[Config]):

          self.config = config

-         if isinstance(self.config, Container):
-             OmegaConf.resolve(self.config)
-
      def on_step_start(self, state: State) -> State:
          return state

@@ -195,7 +193,15 @@ class BaseTask(Generic[Config]):
          cfg = OmegaConf.merge(cfg, *(get_config(path, task_path) for path in paths))
          cfg = OmegaConf.merge(cfg, OmegaConf.from_cli(non_paths))

-         return cast(Config, cfg)
+         return cast(
+             Config,
+             OmegaConf.to_container(
+                 cfg,
+                 resolve=True,
+                 throw_on_missing=True,
+                 structured_config_mode=SCMode.INSTANTIATE,
+             ),
+         )

      @classmethod
      def config_str(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> str:
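The new return path materializes the merged DictConfig into an actual structured-config instance via OmegaConf.to_container with SCMode.INSTANTIATE, resolving interpolations and failing fast on MISSING values instead of resolving lazily in __init__. An illustrative sketch of that OmegaConf behaviour; the Config dataclass and values below are assumptions for illustration, not xax code:

# Sketch: converting a merged DictConfig back into a real dataclass instance.
from dataclasses import dataclass

from omegaconf import MISSING, OmegaConf
from omegaconf.base import SCMode


@dataclass
class Config:
    batch_size: int = MISSING
    learning_rate: float = 1e-3


cfg = OmegaConf.merge(OmegaConf.structured(Config), OmegaConf.from_dotlist(["batch_size=32"]))
config = OmegaConf.to_container(
    cfg,
    resolve=True,              # resolve ${...} interpolations now
    throw_on_missing=True,     # raise immediately if anything is still MISSING
    structured_config_mode=SCMode.INSTANTIATE,  # return a Config instance, not a dict
)
assert isinstance(config, Config) and config.batch_size == 32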
xax/task/mixins/train.py CHANGED
@@ -53,8 +53,10 @@ from xax.utils.experiments import (
      get_packages_with_versions,
      get_training_code,
  )
+ from xax.utils.jax import jit as xax_jit
  from xax.utils.logging import LOG_STATUS
  from xax.utils.text import highlight_exception_message, show_info
+ from xax.utils.types.frozen_dict import FrozenDict

  logger = logging.getLogger(__name__)

@@ -212,32 +214,31 @@ class TrainMixin(

      def on_step_end(self, state: State) -> State:
          state = super().on_step_end(state)
-         state.elapsed_time_s = time.time() - state.start_time_s
-         return state
+         return state.replace(elapsed_time_s=time.time() - state.start_time_s)

-     def log_train_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
+     def log_train_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
          """Override this function to do logging during the training phase.

          This function is called after the model forward pass and before the
          backward pass. It is called in the training phase.

          Args:
-             model: The current model.
              batch: The batch from the dataloader.
              output: The model output.
+             metrics: The metrics for the current batch.
              state: The current training state.
          """

-     def log_valid_step(self, model: PyTree, batch: Batch, output: Output, state: State) -> None:
+     def log_valid_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
          """Override this function to do logging during the validation phase.

          This function is called after the model forward pass. It is called in
          the validation phase.

          Args:
-             model: The current model.
              batch: The batch from the dataloader.
              output: The model output.
+             metrics: The metrics for the current batch.
              state: The current training state.
          """

@@ -248,18 +249,23 @@ class TrainMixin(
          for k, v in d.items():
              self.logger.log_scalar(k, v, namespace=ns)

-     def log_step(self, model: PyTree, batch: Batch, output: Output, loss: Array, state: State) -> None:
+     def log_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
          phase = state.phase

-         self.logger.log_scalar("loss", loss, namespace="loss")
+         for k, v in metrics.items():
+             if v.size == 1:
+                 self.logger.log_scalar(k, v.item())
+             else:
+                 self.logger.log_histogram(k, v)
+
          self.log_state_timers(state)

          # Delegate to the appropriate logging function based on the phase.
          match phase:
              case "train":
-                 self.log_train_step(model, batch, output, state)
+                 self.log_train_step(batch, output, metrics, state)
              case "valid":
-                 self.log_valid_step(model, batch, output, state)
+                 self.log_valid_step(batch, output, metrics, state)
              case _:
                  raise KeyError(f"Unknown phase: {phase}")

@@ -332,8 +338,7 @@ class TrainMixin(

          return model, optimizer, opt_state, state

-     @eqx.filter_jit
-     def get_output(self, model: PyTree, batch: Batch) -> Output:
+     def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
          """Gets the output from the model.

          By default, we assume the model is a function that takes the batch as
@@ -343,11 +348,11 @@ class TrainMixin(
          Args:
              model: The current model.
              batch: The current minibatch of samples.
+             state: The current training state.
          """
          raise NotImplementedError("`get_output` must be implemented by the subclass")

-     @eqx.filter_jit
-     def compute_loss(self, model: PyTree, batch: Batch, output: Output) -> Array:
+     def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
          """Gets the loss for the current batch.

          By default, we assume the model is a function that takes the batch as
@@ -358,6 +363,7 @@ class TrainMixin(
              model: The current model.
              batch: The current minibatch of samples.
              output: The output from the model.
+             state: The current training state.

          Returns:
              The computed loss, as a tensor.
@@ -366,24 +372,59 @@ class TrainMixin(
              raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
          return output

-     @eqx.filter_jit
-     def get_output_and_loss(self, model: PyTree, batch: Batch) -> tuple[Array, Output]:
-         output = self.get_output(model, batch)
-         loss = self.compute_loss(model, batch, output)
-         return loss, output
+     def compute_metrics(
+         self,
+         model: PyTree,
+         batch: Batch,
+         output: Output,
+         loss: Array,
+         state: State,
+     ) -> dict[str, Array]:
+         """Computes the metrics for the current batch.
+
+         Args:
+             model: The current model.
+             batch: The current minibatch of samples.
+             output: The output from the model.
+             loss: The loss for the current batch.
+             state: The current training state.
+
+         Returns:
+             A dictionary of metrics.
+         """
+         return {
+             "loss": loss,
+         }
+
+     @xax_jit(static_argnames=["self", "model_static"])
+     def get_output_and_loss(
+         self,
+         model_arr: PyTree,
+         model_static: PyTree,
+         batch: Batch,
+         state: State,
+     ) -> tuple[Array, tuple[Output, FrozenDict[str, Array]]]:
+         model = eqx.combine(model_arr, model_static)
+         output = self.get_output(model, batch, state)
+         loss = self.compute_loss(model, batch, output, state)
+         metrics = self.compute_metrics(model, batch, output, loss, state)
+         return loss, (output, FrozenDict(metrics))

-     @eqx.filter_jit
      def update(
          self,
-         model: PyTree,
+         model_arr: PyTree,
+         model_static: PyTree,
          optimizer: optax.GradientTransformation,
          opt_state: optax.OptState,
          batch: Batch,
-     ) -> tuple[Array, PyTree, optax.OptState, Output]:
-         (loss, output), grads = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)(model, batch)
-         updates, opt_state = optimizer.update(grads, opt_state, model)
-         model = eqx.apply_updates(model, updates)
-         return loss, model, opt_state, output
+         state: State,
+     ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+         grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
+         grad_fn = xax_jit(static_argnums=[1])(grad_fn)
+         grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
+         updates, opt_state = optimizer.update(grads, opt_state, model_arr)
+         model_arr = eqx.apply_updates(model_arr, updates)
+         return model_arr, opt_state, output, metrics

      def get_size_of_batch(self, batch: Batch) -> int | None:
          """Gets the batch size for the current batch.
@@ -457,21 +498,32 @@ class TrainMixin(
          self.logger.log_file("training_code.txt", get_training_code(self))
          self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))

-     @eqx.filter_jit
+     def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
+         return eqx.is_inexact_array(item)
+
+     @xax_jit(static_argnames=["self", "model_static", "optimizer"])
      def train_step(
          self,
-         model: PyTree,
+         model_arr: PyTree,
+         model_static: PyTree,
          optimizer: optax.GradientTransformation,
          opt_state: optax.OptState,
          batch: Batch,
-     ) -> tuple[PyTree, optax.OptState, Array, Output]:
-         loss, model, opt_state, output = self.update(model, optimizer, opt_state, batch)
-         return model, opt_state, loss, output
+         state: State,
+     ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+         model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
+         return model_arr, opt_state, output, metrics

-     @eqx.filter_jit
-     def val_step(self, model: PyTree, batch: Batch) -> tuple[PyTree, Array, Output]:
-         loss, output = eqx.filter_jit(self.get_output_and_loss)(model, batch)
-         return model, loss, output
+     @xax_jit(static_argnames=["self", "model_static"])
+     def val_step(
+         self,
+         model_arr: PyTree,
+         model_static: PyTree,
+         batch: Batch,
+         state: State,
+     ) -> tuple[Output, FrozenDict[str, Array]]:
+         _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
+         return output, metrics

      def train_loop(
          self,
@@ -482,36 +534,46 @@ class TrainMixin(
          valid_pf: Iterator[Batch],
          state: State,
      ) -> None:
+         model_arr, model_static = eqx.partition(model, self.model_partition_fn)
+
          while not self.is_training_over(state):
              if self.valid_step_timer.is_valid_step(state):
                  valid_batch = next(valid_pf)
-                 with self.step_context("model_step"):
-                     model, loss, output = self.val_step(model, valid_batch)
+                 state = state.replace(
+                     phase="valid",
+                     num_valid_steps=state.num_valid_steps + 1,
+                     num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
+                 )

-                 # Perform logging.
-                 with self.step_context("write_logs"):
-                     state.phase = "valid"
-                     self.log_step(model, valid_batch, output, loss, state)
-                     state.num_valid_samples += 1
+                 output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
+                 self.log_step(valid_batch, output, metrics, state)

              state = self.on_step_start(state)
-
-             with self.step_context("model_step"):
-                 train_batch = next(train_pf)
-                 model, opt_state, loss, output = self.train_step(model, optimizer, opt_state, train_batch)
-
-             with self.step_context("write_logs"):
-                 state.phase = "train"
-                 self.log_step(model, train_batch, output, loss, state)
-                 state.num_steps += 1
-                 state.num_samples += self.get_size_of_batch(train_batch) or 0
+             train_batch = next(train_pf)
+             state = state.replace(
+                 phase="train",
+                 num_steps=state.num_steps + 1,
+                 num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
+             )
+
+             model_arr, opt_state, output, metrics = self.train_step(
+                 model_arr=model_arr,
+                 model_static=model_static,
+                 optimizer=optimizer,
+                 opt_state=opt_state,
+                 batch=train_batch,
+                 state=state,
+             )
+             self.log_step(train_batch, output, metrics, state)

              state = self.on_step_end(state)

              if self.should_checkpoint(state):
+                 model = eqx.combine(model_arr, model_static)
                  self.save_checkpoint(model, optimizer, opt_state, state)

          # After finishing training, save the final checkpoint.
+         model = eqx.combine(model_arr, model_static)
          self.save_checkpoint(model, optimizer, opt_state, state)

      @contextlib.contextmanager
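The training step now partitions the model into an array half and a static half with eqx.partition, takes gradients of the array half only via jax.grad(..., has_aux=True) so metrics ride along as auxiliary outputs, and recombines with eqx.combine whenever the full model is needed. A standalone sketch of that pattern; the MLP, optimizer, and data below are assumptions for illustration, not xax code:

# Sketch of the partition/grad/combine training-step pattern.
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

model = eqx.nn.MLP(in_size=3, out_size=1, width_size=16, depth=2, key=jax.random.PRNGKey(0))
model_arr, model_static = eqx.partition(model, eqx.is_inexact_array)

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(model_arr)


def loss_fn(model_arr, model_static, x, y):
    model = eqx.combine(model_arr, model_static)  # rebuild the full model inside the traced fn
    pred = jax.vmap(model)(x)
    loss = jnp.mean((pred - y) ** 2)
    return loss, {"loss": loss}  # aux output carries the metrics dictionary


x, y = jnp.ones((8, 3)), jnp.zeros((8, 1))
grads, metrics = jax.grad(loss_fn, argnums=0, has_aux=True)(model_arr, model_static, x, y)
updates, opt_state = optimizer.update(grads, opt_state, model_arr)
model_arr = eqx.apply_updates(model_arr, updates)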
xax/task/script.py CHANGED
@@ -3,6 +3,8 @@
  from dataclasses import dataclass
  from typing import Generic, TypeVar

+ import jax
+
  from xax.task.base import BaseConfig, BaseTask
  from xax.task.mixins import (
      ArtifactsConfig,
@@ -20,6 +22,7 @@ from xax.task.mixins import (
  )


+ @jax.tree_util.register_dataclass
  @dataclass(kw_only=True)
  class ScriptConfig(
      CPUStatsConfig,
xax/utils/pytree.py CHANGED
@@ -31,7 +31,7 @@ def slice_array(x: Array, start: Array, slice_length: int) -> Array:

  def slice_pytree(pytree: PyTree, start: Array, slice_length: int) -> PyTree:
      """Get a slice of a pytree."""
-     return jax.tree_util.tree_map(lambda x: slice_array(x, start, slice_length), pytree)
+     return jax.tree.map(lambda x: slice_array(x, start, slice_length), pytree)


  def flatten_array(x: Array, flatten_size: int) -> Array:
@@ -43,14 +43,14 @@ def flatten_array(x: Array, flatten_size: int) -> Array:

  def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
      """Flatten a pytree into a (flatten_size, ...) pytree."""
-     return jax.tree_util.tree_map(lambda x: flatten_array(x, flatten_size), pytree)
+     return jax.tree.map(lambda x: flatten_array(x, flatten_size), pytree)


  def pytree_has_nans(pytree: PyTree) -> Array:
      """Check if a pytree has any NaNs."""
      has_nans = jax.tree_util.tree_reduce(
          lambda a, b: jnp.logical_or(a, b),
-         jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree),
+         jax.tree.map(lambda x: jnp.any(jnp.isnan(x)), pytree),
      )
      return has_nans

@@ -58,13 +58,13 @@ def pytree_has_nans(pytree: PyTree) -> Array:
  def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
      """Update a pytree based on a condition."""
      # Tricky, need use tree_map because where expects array leafs.
-     return jax.tree_util.tree_map(lambda x, y: jnp.where(cond, x, y), new, original)
+     return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)


  def compute_nan_ratio(pytree: PyTree) -> Array:
      """Computes the ratio of NaNs vs non-NaNs in a given PyTree."""
-     nan_counts = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
-     total_counts = jax.tree_util.tree_map(lambda x: x.size, pytree)
+     nan_counts = jax.tree.map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
+     total_counts = jax.tree.map(lambda x: x.size, pytree)

      total_nans = jax.tree_util.tree_reduce(lambda a, b: a + b, nan_counts, 0)
      total_elements = jax.tree_util.tree_reduce(lambda a, b: a + b, total_counts, 0)
@@ -118,7 +118,7 @@ def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArr
          # Reshape back to the original shape
          return permuted.reshape(orig_shape)

-     return jax.tree_util.tree_map(permute_array, data)
+     return jax.tree.map(permute_array, data)


  def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
@@ -133,7 +133,7 @@ def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], r
              return x[tuple(idx_grids)]
          return x

-     return jax.tree_util.tree_map(permute_array, data)
+     return jax.tree.map(permute_array, data)


  TransposeResult = tuple[PyTree, tuple[int, ...], tuple[int, ...]]
@@ -215,7 +215,7 @@ def reshuffle_pytree_along_dims(
              transpose_info[path] = (transpose_order, original_shape)
          return x

-     jax.tree_util.tree_map_with_path(prepare_for_shuffle, data)
+     jax.tree.map_with_path(prepare_for_shuffle, data)

      # Create a transposed pytree
      def get_transposed(path: PathType, x: PyTree) -> PyTree:
@@ -223,7 +223,7 @@ def reshuffle_pytree_along_dims(
              return transposed_data[path]
          return x

-     transposed_pytree = jax.tree_util.tree_map_with_path(get_transposed, data)
+     transposed_pytree = jax.tree.map_with_path(get_transposed, data)

      # Reshuffle the transposed pytree along the leading dimensions
      reshuffled_transposed = reshuffle_pytree(transposed_pytree, shape_dims, rng)
@@ -235,4 +235,4 @@ def reshuffle_pytree_along_dims(
              return transpose_back(x, transpose_order, original_shape)
          return x

-     return jax.tree_util.tree_map_with_path(restore_transpose, reshuffled_transposed)
+     return jax.tree.map_with_path(restore_transpose, reshuffled_transposed)
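The pytree helpers switch from jax.tree_util.tree_map to the newer jax.tree.map spelling; on JAX releases that ship the jax.tree namespace the two are interchangeable, as the small check below illustrates (not part of the diff):

# Sketch: jax.tree.map and jax.tree_util.tree_map produce identical results.
import jax
import jax.numpy as jnp

tree = {"a": jnp.arange(3), "b": (jnp.zeros(2), jnp.ones(2))}
doubled_new = jax.tree.map(lambda x: x * 2, tree)
doubled_old = jax.tree_util.tree_map(lambda x: x * 2, tree)
assert jax.tree_util.tree_all(jax.tree.map(jnp.array_equal, doubled_new, doubled_old))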
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: xax
- Version: 0.1.8
+ Version: 0.1.10
  Summary: A library for fast Jax experimentation
  Home-page: https://github.com/kscalelabs/xax
  Author: Benjamin Bolte
@@ -1,10 +1,10 @@
- xax/__init__.py,sha256=9WNjoeAF7enu7YXQqshpVG1FucGdSkxwrRa-ELDDuUs,13713
+ xax/__init__.py,sha256=bvOBMlEVA46I7ILGfk5AbpwpcdTAjw-4vWI7ci7L7-g,13392
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
  xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
  xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
- xax/core/state.py,sha256=y123fL7pMgk25TPG6KN0LRIF_eYnD9eP7OfqtoQJGNE,2178
+ xax/core/state.py,sha256=WwW0qDm-be9MMOT-bGWEFvaWF4iq2FP9xRSn1zq_4A8,2507
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
  xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
@@ -14,9 +14,9 @@ xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
  xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- xax/task/base.py,sha256=4fUjrG-llQpeESQuaQbww4M6WR6djjTK89fY20UV9zU,7610
+ xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
  xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
- xax/task/script.py,sha256=zt36Sobdoer86gXHqc4sMAW7bqZRVl6IEExuQZH2USk,926
+ xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
  xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
@@ -39,7 +39,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
- xax/task/mixins/train.py,sha256=vsH_QpyrThlh9AzWnyvDJv58Y8U_516oi8gmMq_0iMg,22333
+ xax/task/mixins/train.py,sha256=jAzc9RD25DbhekvItzsRQQrK9aEwtA_sXy0m2Hfkuxo,24594
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
  xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -48,7 +48,7 @@ xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
  xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
  xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
  xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
- xax/utils/pytree.py,sha256=7GjQoPc_ZSZt3QS_9qXoBWl1jfMp1qZa7aViQoWJ0OQ,8864
+ xax/utils/pytree.py,sha256=VFWhT0MQ99KjQyEYM6NFbqYq4_hOZwB23uhowMB4U34,8754
  xax/utils/tensorboard.py,sha256=21czW8WC2SAmwEhz6RLJc_q5HFvNKM4iR1ZycSO5qPE,17058
  xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -56,8 +56,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
- xax-0.1.8.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
- xax-0.1.8.dist-info/METADATA,sha256=wnBSNRByXJzgQPuZqNWooidfFdqcT4w8gbwlBgzbJk8,1877
- xax-0.1.8.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- xax-0.1.8.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
- xax-0.1.8.dist-info/RECORD,,
+ xax-0.1.10.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+ xax-0.1.10.dist-info/METADATA,sha256=kJ1lxZ6cWrtJ5R-adTorzEE_1l0VRJ67xfuBjYXG9Vo,1878
+ xax-0.1.10.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ xax-0.1.10.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+ xax-0.1.10.dist-info/RECORD,,