xax 0.1.9__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
  python -m scripts.update_api --inplace
  """
 
- __version__ = "0.1.9"
+ __version__ = "0.1.10"
 
  # This list shouldn't be modified by hand; instead, run the update script.
  __all__ = [
xax/task/mixins/train.py CHANGED
@@ -56,6 +56,7 @@ from xax.utils.experiments import (
  from xax.utils.jax import jit as xax_jit
  from xax.utils.logging import LOG_STATUS
  from xax.utils.text import highlight_exception_message, show_info
+ from xax.utils.types.frozen_dict import FrozenDict
 
  logger = logging.getLogger(__name__)
 
@@ -215,7 +216,7 @@ class TrainMixin(
  state = super().on_step_end(state)
  return state.replace(elapsed_time_s=time.time() - state.start_time_s)
 
- def log_train_step(self, batch: Batch, output: Output, state: State) -> None:
+ def log_train_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
      """Override this function to do logging during the training phase.
 
      This function is called after the model forward pass and before the
@@ -224,10 +225,11 @@ class TrainMixin(
  Args:
      batch: The batch from the dataloader.
      output: The model output.
+     metrics: The metrics for the current batch.
      state: The current training state.
  """
 
- def log_valid_step(self, batch: Batch, output: Output, state: State) -> None:
+ def log_valid_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
      """Override this function to do logging during the validation phase.
 
      This function is called after the model forward pass. It is called in
@@ -236,6 +238,7 @@ class TrainMixin(
  Args:
      batch: The batch from the dataloader.
      output: The model output.
+     metrics: The metrics for the current batch.
      state: The current training state.
  """
 
@@ -246,18 +249,23 @@ class TrainMixin(
  for k, v in d.items():
      self.logger.log_scalar(k, v, namespace=ns)
 
- def log_step(self, batch: Batch, output: Output, loss: Array, state: State) -> None:
+ def log_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
      phase = state.phase
 
-     self.logger.log_scalar("loss", loss, namespace="loss")
+     for k, v in metrics.items():
+         if v.size == 1:
+             self.logger.log_scalar(k, v.item())
+         else:
+             self.logger.log_histogram(k, v)
+
      self.log_state_timers(state)
 
      # Delegate to the appropriate logging function based on the phase.
      match phase:
          case "train":
-             self.log_train_step(batch, output, state)
+             self.log_train_step(batch, output, metrics, state)
          case "valid":
-             self.log_valid_step(batch, output, state)
+             self.log_valid_step(batch, output, metrics, state)
          case _:
              raise KeyError(f"Unknown phase: {phase}")
 
@@ -364,32 +372,59 @@ class TrainMixin(
          raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
      return output
 
+ def compute_metrics(
+     self,
+     model: PyTree,
+     batch: Batch,
+     output: Output,
+     loss: Array,
+     state: State,
+ ) -> dict[str, Array]:
+     """Computes the metrics for the current batch.
+
+     Args:
+         model: The current model.
+         batch: The current minibatch of samples.
+         output: The output from the model.
+         loss: The loss for the current batch.
+         state: The current training state.
+
+     Returns:
+         A dictionary of metrics.
+     """
+     return {
+         "loss": loss,
+     }
+
+ @xax_jit(static_argnames=["self", "model_static"])
  def get_output_and_loss(
      self,
-     model_static: PyTree,
      model_arr: PyTree,
+     model_static: PyTree,
      batch: Batch,
      state: State,
- ) -> tuple[Array, Output]:
+ ) -> tuple[Array, tuple[Output, FrozenDict[str, Array]]]:
      model = eqx.combine(model_arr, model_static)
      output = self.get_output(model, batch, state)
      loss = self.compute_loss(model, batch, output, state)
-     return loss, output
+     metrics = self.compute_metrics(model, batch, output, loss, state)
+     return loss, (output, FrozenDict(metrics))
 
  def update(
      self,
-     model_static: PyTree,
      model_arr: PyTree,
+     model_static: PyTree,
      optimizer: optax.GradientTransformation,
      opt_state: optax.OptState,
      batch: Batch,
      state: State,
- ) -> tuple[Array, PyTree, optax.OptState, Output]:
-     grad_fn = eqx.filter_value_and_grad(self.get_output_and_loss, has_aux=True)
-     (loss, output), grads = grad_fn(model_static, model_arr, batch, state)
+ ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+     grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
+     grad_fn = xax_jit(static_argnums=[1])(grad_fn)
+     grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
      updates, opt_state = optimizer.update(grads, opt_state, model_arr)
      model_arr = eqx.apply_updates(model_arr, updates)
-     return loss, model_arr, opt_state, output
+     return model_arr, opt_state, output, metrics
 
  def get_size_of_batch(self, batch: Batch) -> int | None:
      """Gets the batch size for the current batch.
@@ -469,25 +504,26 @@ class TrainMixin(
  @xax_jit(static_argnames=["self", "model_static", "optimizer"])
  def train_step(
      self,
-     model_static: PyTree,
      model_arr: PyTree,
+     model_static: PyTree,
      optimizer: optax.GradientTransformation,
      opt_state: optax.OptState,
      batch: Batch,
      state: State,
- ) -> tuple[PyTree, optax.OptState, Array, Output]:
-     loss, model_arr, opt_state, output = self.update(model_static, model_arr, optimizer, opt_state, batch, state)
-     return model_arr, opt_state, loss, output
+ ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+     model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
+     return model_arr, opt_state, output, metrics
 
  @xax_jit(static_argnames=["self", "model_static"])
  def val_step(
      self,
-     model_static: PyTree,
      model_arr: PyTree,
+     model_static: PyTree,
      batch: Batch,
      state: State,
- ) -> tuple[Array, Output]:
-     return self.get_output_and_loss(model_static, model_arr, batch, state)
+ ) -> tuple[Output, FrozenDict[str, Array]]:
+     _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
+     return output, metrics
 
  def train_loop(
      self,
@@ -509,8 +545,8 @@ class TrainMixin(
      num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
  )
 
- output, metrics = self.val_step(model_static, model_arr, valid_batch, state)
- self.log_step(valid_batch, output, loss, state)
+ output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
+ self.log_step(valid_batch, output, metrics, state)
 
  state = self.on_step_start(state)
  train_batch = next(train_pf)
@@ -520,15 +556,15 @@ class TrainMixin(
      num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
  )
 
- model_arr, opt_state, loss, output = self.train_step(
-     model_static=model_static,
+ model_arr, opt_state, output, metrics = self.train_step(
      model_arr=model_arr,
+     model_static=model_static,
      optimizer=optimizer,
      opt_state=opt_state,
      batch=train_batch,
      state=state,
  )
- self.log_step(train_batch, output, loss, state)
+ self.log_step(train_batch, output, metrics, state)
 
  state = self.on_step_end(state)
 
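Taken together, the train.py changes replace the bare loss value that used to flow through log_step, log_train_step, and log_valid_step with a FrozenDict[str, Array] of metrics produced by the new overridable compute_metrics hook. The model arrays also now come first in get_output_and_loss, update, train_step, and val_step, since jax.grad(..., argnums=0) differentiates with respect to the first argument. Below is a minimal, hypothetical sketch of how a downstream task might use the new hooks; the MyTask class, its xax.Task base, and the extra metric are illustrative assumptions, not part of this diff.

import jax.numpy as jnp
from jaxtyping import Array

import xax
from xax.utils.types.frozen_dict import FrozenDict


class MyTask(xax.Task):  # base class name is an assumption; adapt to your own task setup
    def compute_metrics(self, model, batch, output, loss: Array, state) -> dict[str, Array]:
        # The default implementation returns {"loss": loss}; extend it with extra entries.
        metrics = super().compute_metrics(model, batch, output, loss, state)
        # Hypothetical extra metric; a real task would derive values from output and batch.
        metrics["loss_squared"] = jnp.square(loss)
        return metrics

    def log_train_step(self, batch, output, metrics: FrozenDict[str, Array], state) -> None:
        # The new metrics argument carries per-batch values computed inside the jitted step;
        # log_step has already emitted each entry via log_scalar (size-1) or log_histogram.
        pass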
xax/utils/pytree.py CHANGED
@@ -31,7 +31,7 @@ def slice_array(x: Array, start: Array, slice_length: int) -> Array:
 
  def slice_pytree(pytree: PyTree, start: Array, slice_length: int) -> PyTree:
      """Get a slice of a pytree."""
-     return jax.tree_util.tree_map(lambda x: slice_array(x, start, slice_length), pytree)
+     return jax.tree.map(lambda x: slice_array(x, start, slice_length), pytree)
 
 
  def flatten_array(x: Array, flatten_size: int) -> Array:
@@ -43,14 +43,14 @@ def flatten_array(x: Array, flatten_size: int) -> Array:
 
  def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
      """Flatten a pytree into a (flatten_size, ...) pytree."""
-     return jax.tree_util.tree_map(lambda x: flatten_array(x, flatten_size), pytree)
+     return jax.tree.map(lambda x: flatten_array(x, flatten_size), pytree)
 
 
  def pytree_has_nans(pytree: PyTree) -> Array:
      """Check if a pytree has any NaNs."""
      has_nans = jax.tree_util.tree_reduce(
          lambda a, b: jnp.logical_or(a, b),
-         jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), pytree),
+         jax.tree.map(lambda x: jnp.any(jnp.isnan(x)), pytree),
      )
      return has_nans
 
@@ -58,13 +58,13 @@ def pytree_has_nans(pytree: PyTree) -> Array:
 
  def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
      """Update a pytree based on a condition."""
      # Tricky, need use tree_map because where expects array leafs.
-     return jax.tree_util.tree_map(lambda x, y: jnp.where(cond, x, y), new, original)
+     return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)
 
 
  def compute_nan_ratio(pytree: PyTree) -> Array:
      """Computes the ratio of NaNs vs non-NaNs in a given PyTree."""
-     nan_counts = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
-     total_counts = jax.tree_util.tree_map(lambda x: x.size, pytree)
+     nan_counts = jax.tree.map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
+     total_counts = jax.tree.map(lambda x: x.size, pytree)
 
      total_nans = jax.tree_util.tree_reduce(lambda a, b: a + b, nan_counts, 0)
      total_elements = jax.tree_util.tree_reduce(lambda a, b: a + b, total_counts, 0)
@@ -118,7 +118,7 @@ def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArr
          # Reshape back to the original shape
          return permuted.reshape(orig_shape)
 
-     return jax.tree_util.tree_map(permute_array, data)
+     return jax.tree.map(permute_array, data)
 
 
  def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
@@ -133,7 +133,7 @@ def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], r
              return x[tuple(idx_grids)]
          return x
 
-     return jax.tree_util.tree_map(permute_array, data)
+     return jax.tree.map(permute_array, data)
 
 
  TransposeResult = tuple[PyTree, tuple[int, ...], tuple[int, ...]]
@@ -215,7 +215,7 @@ def reshuffle_pytree_along_dims(
          transpose_info[path] = (transpose_order, original_shape)
          return x
 
-     jax.tree_util.tree_map_with_path(prepare_for_shuffle, data)
+     jax.tree.map_with_path(prepare_for_shuffle, data)
 
      # Create a transposed pytree
      def get_transposed(path: PathType, x: PyTree) -> PyTree:
@@ -223,7 +223,7 @@ def reshuffle_pytree_along_dims(
              return transposed_data[path]
          return x
 
-     transposed_pytree = jax.tree_util.tree_map_with_path(get_transposed, data)
+     transposed_pytree = jax.tree.map_with_path(get_transposed, data)
 
      # Reshuffle the transposed pytree along the leading dimensions
      reshuffled_transposed = reshuffle_pytree(transposed_pytree, shape_dims, rng)
@@ -235,4 +235,4 @@ def reshuffle_pytree_along_dims(
              return transpose_back(x, transpose_order, original_shape)
          return x
 
-     return jax.tree_util.tree_map_with_path(restore_transpose, reshuffled_transposed)
+     return jax.tree.map_with_path(restore_transpose, reshuffled_transposed)
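The pytree.py changes are mechanical: every jax.tree_util.tree_map / tree_map_with_path call is rewritten to the newer jax.tree.map / jax.tree.map_with_path spelling, with no behavioral change. A minimal sketch (not from the package, and assuming a JAX release recent enough to provide the jax.tree namespace) showing the two spellings side by side:

import jax
import jax.numpy as jnp

tree = {"a": jnp.arange(3), "b": (jnp.ones(2), jnp.zeros(4))}

doubled_old = jax.tree_util.tree_map(lambda x: x * 2, tree)  # older spelling, still valid JAX
doubled_new = jax.tree.map(lambda x: x * 2, tree)            # spelling xax uses from 0.1.10 onward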
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: xax
- Version: 0.1.9
+ Version: 0.1.10
  Summary: A library for fast Jax experimentation
  Home-page: https://github.com/kscalelabs/xax
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
- xax/__init__.py,sha256=_xb60-jl7arZEleSwUw4ElPaq4MzD24_ZYQrnWO5_cs,13391
+ xax/__init__.py,sha256=bvOBMlEVA46I7ILGfk5AbpwpcdTAjw-4vWI7ci7L7-g,13392
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
  xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -39,7 +39,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
- xax/task/mixins/train.py,sha256=JbrSiBqpgOrdDanNYuAzzh2radPrXOVrHYA6VcxjIzY,23248
+ xax/task/mixins/train.py,sha256=jAzc9RD25DbhekvItzsRQQrK9aEwtA_sXy0m2Hfkuxo,24594
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
  xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -48,7 +48,7 @@ xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
  xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
  xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
  xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
- xax/utils/pytree.py,sha256=7GjQoPc_ZSZt3QS_9qXoBWl1jfMp1qZa7aViQoWJ0OQ,8864
+ xax/utils/pytree.py,sha256=VFWhT0MQ99KjQyEYM6NFbqYq4_hOZwB23uhowMB4U34,8754
  xax/utils/tensorboard.py,sha256=21czW8WC2SAmwEhz6RLJc_q5HFvNKM4iR1ZycSO5qPE,17058
  xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -56,8 +56,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
- xax-0.1.9.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
- xax-0.1.9.dist-info/METADATA,sha256=Ou8KmYWWNxgo_9ZAU2KLaeGeXAxd6b9qJ95ky4HRm-o,1877
- xax-0.1.9.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
- xax-0.1.9.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
- xax-0.1.9.dist-info/RECORD,,
+ xax-0.1.10.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+ xax-0.1.10.dist-info/METADATA,sha256=kJ1lxZ6cWrtJ5R-adTorzEE_1l0VRJ67xfuBjYXG9Vo,1878
+ xax-0.1.10.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+ xax-0.1.10.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+ xax-0.1.10.dist-info/RECORD,,