xax 0.2.8__tar.gz → 0.2.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.2.8/xax.egg-info → xax-0.2.10}/PKG-INFO +1 -1
  2. {xax-0.2.8 → xax-0.2.10}/xax/__init__.py +2 -2
  3. {xax-0.2.8 → xax-0.2.10}/xax/task/mixins/train.py +3 -4
  4. {xax-0.2.8 → xax-0.2.10/xax.egg-info}/PKG-INFO +1 -1
  5. {xax-0.2.8 → xax-0.2.10}/LICENSE +0 -0
  6. {xax-0.2.8 → xax-0.2.10}/MANIFEST.in +0 -0
  7. {xax-0.2.8 → xax-0.2.10}/README.md +0 -0
  8. {xax-0.2.8 → xax-0.2.10}/pyproject.toml +0 -0
  9. {xax-0.2.8 → xax-0.2.10}/setup.cfg +0 -0
  10. {xax-0.2.8 → xax-0.2.10}/setup.py +0 -0
  11. {xax-0.2.8 → xax-0.2.10}/xax/core/__init__.py +0 -0
  12. {xax-0.2.8 → xax-0.2.10}/xax/core/conf.py +0 -0
  13. {xax-0.2.8 → xax-0.2.10}/xax/core/state.py +0 -0
  14. {xax-0.2.8 → xax-0.2.10}/xax/nn/__init__.py +0 -0
  15. {xax-0.2.8 → xax-0.2.10}/xax/nn/embeddings.py +0 -0
  16. {xax-0.2.8 → xax-0.2.10}/xax/nn/equinox.py +0 -0
  17. {xax-0.2.8 → xax-0.2.10}/xax/nn/export.py +0 -0
  18. {xax-0.2.8 → xax-0.2.10}/xax/nn/functions.py +0 -0
  19. {xax-0.2.8 → xax-0.2.10}/xax/nn/geom.py +0 -0
  20. {xax-0.2.8 → xax-0.2.10}/xax/nn/losses.py +0 -0
  21. {xax-0.2.8 → xax-0.2.10}/xax/nn/norm.py +0 -0
  22. {xax-0.2.8 → xax-0.2.10}/xax/nn/parallel.py +0 -0
  23. {xax-0.2.8 → xax-0.2.10}/xax/nn/ssm.py +0 -0
  24. {xax-0.2.8 → xax-0.2.10}/xax/py.typed +0 -0
  25. {xax-0.2.8 → xax-0.2.10}/xax/requirements-dev.txt +0 -0
  26. {xax-0.2.8 → xax-0.2.10}/xax/requirements.txt +0 -0
  27. {xax-0.2.8 → xax-0.2.10}/xax/task/__init__.py +0 -0
  28. {xax-0.2.8 → xax-0.2.10}/xax/task/base.py +0 -0
  29. {xax-0.2.8 → xax-0.2.10}/xax/task/launchers/__init__.py +0 -0
  30. {xax-0.2.8 → xax-0.2.10}/xax/task/launchers/base.py +0 -0
  31. {xax-0.2.8 → xax-0.2.10}/xax/task/launchers/cli.py +0 -0
  32. {xax-0.2.8 → xax-0.2.10}/xax/task/launchers/single_process.py +0 -0
  33. {xax-0.2.8 → xax-0.2.10}/xax/task/logger.py +0 -0
  34. {xax-0.2.8 → xax-0.2.10}/xax/task/loggers/__init__.py +0 -0
  35. {xax-0.2.8 → xax-0.2.10}/xax/task/loggers/callback.py +0 -0
  36. {xax-0.2.8 → xax-0.2.10}/xax/task/loggers/json.py +0 -0
  37. {xax-0.2.8 → xax-0.2.10}/xax/task/loggers/state.py +0 -0
  38. {xax-0.2.8 → xax-0.2.10}/xax/task/loggers/stdout.py +0 -0
  39. {xax-0.2.8 → xax-0.2.10}/xax/task/loggers/tensorboard.py +0 -0
  40. {xax-0.2.8 → xax-0.2.10}/xax/task/mixins/__init__.py +0 -0
  41. {xax-0.2.8 → xax-0.2.10}/xax/task/mixins/artifacts.py +0 -0
  42. {xax-0.2.8 → xax-0.2.10}/xax/task/mixins/checkpointing.py +0 -0
  43. {xax-0.2.8 → xax-0.2.10}/xax/task/mixins/compile.py +0 -0
  44. {xax-0.2.8 → xax-0.2.10}/xax/task/mixins/cpu_stats.py +0 -0
  45. {xax-0.2.8 → xax-0.2.10}/xax/task/mixins/data_loader.py +0 -0
  46. {xax-0.2.8 → xax-0.2.10}/xax/task/mixins/gpu_stats.py +0 -0
  47. {xax-0.2.8 → xax-0.2.10}/xax/task/mixins/logger.py +0 -0
  48. {xax-0.2.8 → xax-0.2.10}/xax/task/mixins/process.py +0 -0
  49. {xax-0.2.8 → xax-0.2.10}/xax/task/mixins/runnable.py +0 -0
  50. {xax-0.2.8 → xax-0.2.10}/xax/task/mixins/step_wrapper.py +0 -0
  51. {xax-0.2.8 → xax-0.2.10}/xax/task/script.py +0 -0
  52. {xax-0.2.8 → xax-0.2.10}/xax/task/task.py +0 -0
  53. {xax-0.2.8 → xax-0.2.10}/xax/utils/__init__.py +0 -0
  54. {xax-0.2.8 → xax-0.2.10}/xax/utils/data/__init__.py +0 -0
  55. {xax-0.2.8 → xax-0.2.10}/xax/utils/data/collate.py +0 -0
  56. {xax-0.2.8 → xax-0.2.10}/xax/utils/debugging.py +0 -0
  57. {xax-0.2.8 → xax-0.2.10}/xax/utils/experiments.py +0 -0
  58. {xax-0.2.8 → xax-0.2.10}/xax/utils/jax.py +0 -0
  59. {xax-0.2.8 → xax-0.2.10}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.2.8 → xax-0.2.10}/xax/utils/logging.py +0 -0
  61. {xax-0.2.8 → xax-0.2.10}/xax/utils/numpy.py +0 -0
  62. {xax-0.2.8 → xax-0.2.10}/xax/utils/profile.py +0 -0
  63. {xax-0.2.8 → xax-0.2.10}/xax/utils/pytree.py +0 -0
  64. {xax-0.2.8 → xax-0.2.10}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.2.8 → xax-0.2.10}/xax/utils/text.py +0 -0
  66. {xax-0.2.8 → xax-0.2.10}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.2.8 → xax-0.2.10}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.2.8 → xax-0.2.10}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.8 → xax-0.2.10}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.8 → xax-0.2.10}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.8 → xax-0.2.10}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.8 → xax-0.2.10}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.8
3
+ Version: 0.2.10
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.8"
15
+ __version__ = "0.2.10"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -308,7 +308,7 @@ NAME_MAP: dict[str, str] = {
308
308
  "compute_nan_ratio": "utils.pytree",
309
309
  "flatten_array": "utils.pytree",
310
310
  "flatten_pytree": "utils.pytree",
311
- "get_param_count": "utils.pytree",
311
+ "get_pytree_param_count": "utils.pytree",
312
312
  "pytree_has_nans": "utils.pytree",
313
313
  "reshuffle_pytree": "utils.pytree",
314
314
  "reshuffle_pytree_along_dims": "utils.pytree",
@@ -361,7 +361,6 @@ class TrainMixin(
361
361
  model = self.get_model(key)
362
362
  state = State.init_state()
363
363
 
364
- self.log_model_size(model)
365
364
  if not load_optimizer:
366
365
  return model, state
367
366
 
@@ -684,9 +683,6 @@ class TrainMixin(
684
683
  self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
685
684
  self.logger.log_file("info.json", get_info_json())
686
685
 
687
- def log_model_size(self, model: PyTree) -> None:
688
- logger.info("Model size: %s", f"{get_pytree_param_count(model):,}")
689
-
690
686
  def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
691
687
  return eqx.is_inexact_array(item)
692
688
 
@@ -832,6 +828,9 @@ class TrainMixin(
832
828
 
833
829
  key, model_key = jax.random.split(key)
834
830
  model, optimizer, opt_state, state = self.load_initial_state(model_key, load_optimizer=True)
831
+ logger.info("Model size: %s", f"{get_pytree_param_count(model):,}")
832
+ logger.info("Optimizer size: %s", f"{get_pytree_param_count(optimizer):,}")
833
+
835
834
  state = self.on_training_start(state)
836
835
 
837
836
  def on_exit() -> None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.8
3
+ Version: 0.2.10
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes