xax 0.1.2__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. {xax-0.1.2/xax.egg-info → xax-0.1.3}/PKG-INFO +1 -1
  2. {xax-0.1.2 → xax-0.1.3}/xax/__init__.py +1 -1
  3. {xax-0.1.2 → xax-0.1.3}/xax/task/mixins/checkpointing.py +0 -10
  4. {xax-0.1.2 → xax-0.1.3}/xax/task/mixins/train.py +2 -0
  5. {xax-0.1.2 → xax-0.1.3}/xax/utils/experiments.py +11 -0
  6. {xax-0.1.2 → xax-0.1.3/xax.egg-info}/PKG-INFO +1 -1
  7. {xax-0.1.2 → xax-0.1.3}/LICENSE +0 -0
  8. {xax-0.1.2 → xax-0.1.3}/MANIFEST.in +0 -0
  9. {xax-0.1.2 → xax-0.1.3}/README.md +0 -0
  10. {xax-0.1.2 → xax-0.1.3}/pyproject.toml +0 -0
  11. {xax-0.1.2 → xax-0.1.3}/setup.cfg +0 -0
  12. {xax-0.1.2 → xax-0.1.3}/setup.py +0 -0
  13. {xax-0.1.2 → xax-0.1.3}/xax/core/__init__.py +0 -0
  14. {xax-0.1.2 → xax-0.1.3}/xax/core/conf.py +0 -0
  15. {xax-0.1.2 → xax-0.1.3}/xax/core/state.py +0 -0
  16. {xax-0.1.2 → xax-0.1.3}/xax/nn/__init__.py +0 -0
  17. {xax-0.1.2 → xax-0.1.3}/xax/nn/embeddings.py +0 -0
  18. {xax-0.1.2 → xax-0.1.3}/xax/nn/equinox.py +0 -0
  19. {xax-0.1.2 → xax-0.1.3}/xax/nn/export.py +0 -0
  20. {xax-0.1.2 → xax-0.1.3}/xax/nn/functions.py +0 -0
  21. {xax-0.1.2 → xax-0.1.3}/xax/nn/geom.py +0 -0
  22. {xax-0.1.2 → xax-0.1.3}/xax/nn/norm.py +0 -0
  23. {xax-0.1.2 → xax-0.1.3}/xax/nn/parallel.py +0 -0
  24. {xax-0.1.2 → xax-0.1.3}/xax/py.typed +0 -0
  25. {xax-0.1.2 → xax-0.1.3}/xax/requirements-dev.txt +0 -0
  26. {xax-0.1.2 → xax-0.1.3}/xax/requirements.txt +0 -0
  27. {xax-0.1.2 → xax-0.1.3}/xax/task/__init__.py +0 -0
  28. {xax-0.1.2 → xax-0.1.3}/xax/task/base.py +0 -0
  29. {xax-0.1.2 → xax-0.1.3}/xax/task/launchers/__init__.py +0 -0
  30. {xax-0.1.2 → xax-0.1.3}/xax/task/launchers/base.py +0 -0
  31. {xax-0.1.2 → xax-0.1.3}/xax/task/launchers/cli.py +0 -0
  32. {xax-0.1.2 → xax-0.1.3}/xax/task/launchers/single_process.py +0 -0
  33. {xax-0.1.2 → xax-0.1.3}/xax/task/logger.py +0 -0
  34. {xax-0.1.2 → xax-0.1.3}/xax/task/loggers/__init__.py +0 -0
  35. {xax-0.1.2 → xax-0.1.3}/xax/task/loggers/callback.py +0 -0
  36. {xax-0.1.2 → xax-0.1.3}/xax/task/loggers/json.py +0 -0
  37. {xax-0.1.2 → xax-0.1.3}/xax/task/loggers/state.py +0 -0
  38. {xax-0.1.2 → xax-0.1.3}/xax/task/loggers/stdout.py +0 -0
  39. {xax-0.1.2 → xax-0.1.3}/xax/task/loggers/tensorboard.py +0 -0
  40. {xax-0.1.2 → xax-0.1.3}/xax/task/mixins/__init__.py +0 -0
  41. {xax-0.1.2 → xax-0.1.3}/xax/task/mixins/artifacts.py +0 -0
  42. {xax-0.1.2 → xax-0.1.3}/xax/task/mixins/compile.py +0 -0
  43. {xax-0.1.2 → xax-0.1.3}/xax/task/mixins/cpu_stats.py +0 -0
  44. {xax-0.1.2 → xax-0.1.3}/xax/task/mixins/data_loader.py +0 -0
  45. {xax-0.1.2 → xax-0.1.3}/xax/task/mixins/gpu_stats.py +0 -0
  46. {xax-0.1.2 → xax-0.1.3}/xax/task/mixins/logger.py +0 -0
  47. {xax-0.1.2 → xax-0.1.3}/xax/task/mixins/process.py +0 -0
  48. {xax-0.1.2 → xax-0.1.3}/xax/task/mixins/runnable.py +0 -0
  49. {xax-0.1.2 → xax-0.1.3}/xax/task/mixins/step_wrapper.py +0 -0
  50. {xax-0.1.2 → xax-0.1.3}/xax/task/script.py +0 -0
  51. {xax-0.1.2 → xax-0.1.3}/xax/task/task.py +0 -0
  52. {xax-0.1.2 → xax-0.1.3}/xax/utils/__init__.py +0 -0
  53. {xax-0.1.2 → xax-0.1.3}/xax/utils/data/__init__.py +0 -0
  54. {xax-0.1.2 → xax-0.1.3}/xax/utils/data/collate.py +0 -0
  55. {xax-0.1.2 → xax-0.1.3}/xax/utils/debugging.py +0 -0
  56. {xax-0.1.2 → xax-0.1.3}/xax/utils/jax.py +0 -0
  57. {xax-0.1.2 → xax-0.1.3}/xax/utils/jaxpr.py +0 -0
  58. {xax-0.1.2 → xax-0.1.3}/xax/utils/logging.py +0 -0
  59. {xax-0.1.2 → xax-0.1.3}/xax/utils/numpy.py +0 -0
  60. {xax-0.1.2 → xax-0.1.3}/xax/utils/profile.py +0 -0
  61. {xax-0.1.2 → xax-0.1.3}/xax/utils/pytree.py +0 -0
  62. {xax-0.1.2 → xax-0.1.3}/xax/utils/tensorboard.py +0 -0
  63. {xax-0.1.2 → xax-0.1.3}/xax/utils/text.py +0 -0
  64. {xax-0.1.2 → xax-0.1.3}/xax.egg-info/SOURCES.txt +0 -0
  65. {xax-0.1.2 → xax-0.1.3}/xax.egg-info/dependency_links.txt +0 -0
  66. {xax-0.1.2 → xax-0.1.3}/xax.egg-info/requires.txt +0 -0
  67. {xax-0.1.2 → xax-0.1.3}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.2"
15
+ __version__ = "0.1.3"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -47,7 +47,6 @@ class CheckpointingConfig(ArtifactsConfig):
47
47
  only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
48
48
  load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
49
49
  load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
50
- save_tf_model: bool = field(False, help="If set, saves a Tensorflow version of the model")
51
50
 
52
51
 
53
52
  Config = TypeVar("Config", bound=CheckpointingConfig)
@@ -213,15 +212,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
213
212
  add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
214
213
  add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
215
214
 
216
- if self.config.save_tf_model:
217
- try:
218
- from jax.experimental import jax2tf
219
- except ModuleNotFoundError:
220
- raise ImportError("Tensorflow is not installed. Install it with `pip install tensorflow`")
221
-
222
- tf_model = jax2tf.convert(model)
223
- add_file("model.tf", lambda buf: cloudpickle.dump(tf_model, buf))
224
-
225
215
  # Updates the symlink to the new checkpoint.
226
216
  last_ckpt_path.unlink(missing_ok=True)
227
217
  try:
@@ -50,6 +50,7 @@ from xax.utils.experiments import (
50
50
  diff_configs,
51
51
  get_diff_string,
52
52
  get_git_state,
53
+ get_packages_with_versions,
53
54
  get_training_code,
54
55
  )
55
56
  from xax.utils.logging import LOG_STATUS
@@ -452,6 +453,7 @@ class TrainMixin(
452
453
  logger.log(LOG_STATUS, self.task_name)
453
454
  logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
454
455
  self.logger.log_file("git_state.txt", get_git_state(self))
456
+ self.logger.log_file("packages.txt", get_packages_with_versions())
455
457
  self.logger.log_file("training_code.txt", get_training_code(self))
456
458
  self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
457
459
 
@@ -28,6 +28,7 @@ from typing import Any, Iterator, Self, TypeVar, cast
28
28
  from urllib.parse import urlparse
29
29
 
30
30
  import git
31
+ import pkg_resources
31
32
  import requests
32
33
  from jaxtyping import Array
33
34
  from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf
@@ -468,6 +469,16 @@ def get_git_state(obj: object) -> str:
468
469
  return traceback.format_exc()
469
470
 
470
471
 
472
+ def get_packages_with_versions() -> str:
473
+ """Gets the packages and their versions.
474
+
475
+ Returns:
476
+ A dictionary of packages and their versions.
477
+ """
478
+ packages = [(pkg.key, pkg.version) for pkg in pkg_resources.working_set]
479
+ return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
480
+
481
+
471
482
  def get_training_code(obj: object) -> str:
472
483
  """Gets the text from the file containing the provided object.
473
484
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes