xax 0.1.1__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. {xax-0.1.1/xax.egg-info → xax-0.1.3}/PKG-INFO +1 -1
  2. {xax-0.1.1 → xax-0.1.3}/xax/__init__.py +4 -2
  3. {xax-0.1.1 → xax-0.1.3}/xax/task/mixins/checkpointing.py +0 -10
  4. {xax-0.1.1 → xax-0.1.3}/xax/task/mixins/train.py +2 -0
  5. {xax-0.1.1 → xax-0.1.3}/xax/utils/experiments.py +13 -2
  6. {xax-0.1.1 → xax-0.1.3/xax.egg-info}/PKG-INFO +1 -1
  7. {xax-0.1.1 → xax-0.1.3}/LICENSE +0 -0
  8. {xax-0.1.1 → xax-0.1.3}/MANIFEST.in +0 -0
  9. {xax-0.1.1 → xax-0.1.3}/README.md +0 -0
  10. {xax-0.1.1 → xax-0.1.3}/pyproject.toml +0 -0
  11. {xax-0.1.1 → xax-0.1.3}/setup.cfg +0 -0
  12. {xax-0.1.1 → xax-0.1.3}/setup.py +0 -0
  13. {xax-0.1.1 → xax-0.1.3}/xax/core/__init__.py +0 -0
  14. {xax-0.1.1 → xax-0.1.3}/xax/core/conf.py +0 -0
  15. {xax-0.1.1 → xax-0.1.3}/xax/core/state.py +0 -0
  16. {xax-0.1.1 → xax-0.1.3}/xax/nn/__init__.py +0 -0
  17. {xax-0.1.1 → xax-0.1.3}/xax/nn/embeddings.py +0 -0
  18. {xax-0.1.1 → xax-0.1.3}/xax/nn/equinox.py +0 -0
  19. {xax-0.1.1 → xax-0.1.3}/xax/nn/export.py +0 -0
  20. {xax-0.1.1 → xax-0.1.3}/xax/nn/functions.py +0 -0
  21. {xax-0.1.1 → xax-0.1.3}/xax/nn/geom.py +0 -0
  22. {xax-0.1.1 → xax-0.1.3}/xax/nn/norm.py +0 -0
  23. {xax-0.1.1 → xax-0.1.3}/xax/nn/parallel.py +0 -0
  24. {xax-0.1.1 → xax-0.1.3}/xax/py.typed +0 -0
  25. {xax-0.1.1 → xax-0.1.3}/xax/requirements-dev.txt +0 -0
  26. {xax-0.1.1 → xax-0.1.3}/xax/requirements.txt +0 -0
  27. {xax-0.1.1 → xax-0.1.3}/xax/task/__init__.py +0 -0
  28. {xax-0.1.1 → xax-0.1.3}/xax/task/base.py +0 -0
  29. {xax-0.1.1 → xax-0.1.3}/xax/task/launchers/__init__.py +0 -0
  30. {xax-0.1.1 → xax-0.1.3}/xax/task/launchers/base.py +0 -0
  31. {xax-0.1.1 → xax-0.1.3}/xax/task/launchers/cli.py +0 -0
  32. {xax-0.1.1 → xax-0.1.3}/xax/task/launchers/single_process.py +0 -0
  33. {xax-0.1.1 → xax-0.1.3}/xax/task/logger.py +0 -0
  34. {xax-0.1.1 → xax-0.1.3}/xax/task/loggers/__init__.py +0 -0
  35. {xax-0.1.1 → xax-0.1.3}/xax/task/loggers/callback.py +0 -0
  36. {xax-0.1.1 → xax-0.1.3}/xax/task/loggers/json.py +0 -0
  37. {xax-0.1.1 → xax-0.1.3}/xax/task/loggers/state.py +0 -0
  38. {xax-0.1.1 → xax-0.1.3}/xax/task/loggers/stdout.py +0 -0
  39. {xax-0.1.1 → xax-0.1.3}/xax/task/loggers/tensorboard.py +0 -0
  40. {xax-0.1.1 → xax-0.1.3}/xax/task/mixins/__init__.py +0 -0
  41. {xax-0.1.1 → xax-0.1.3}/xax/task/mixins/artifacts.py +0 -0
  42. {xax-0.1.1 → xax-0.1.3}/xax/task/mixins/compile.py +0 -0
  43. {xax-0.1.1 → xax-0.1.3}/xax/task/mixins/cpu_stats.py +0 -0
  44. {xax-0.1.1 → xax-0.1.3}/xax/task/mixins/data_loader.py +0 -0
  45. {xax-0.1.1 → xax-0.1.3}/xax/task/mixins/gpu_stats.py +0 -0
  46. {xax-0.1.1 → xax-0.1.3}/xax/task/mixins/logger.py +0 -0
  47. {xax-0.1.1 → xax-0.1.3}/xax/task/mixins/process.py +0 -0
  48. {xax-0.1.1 → xax-0.1.3}/xax/task/mixins/runnable.py +0 -0
  49. {xax-0.1.1 → xax-0.1.3}/xax/task/mixins/step_wrapper.py +0 -0
  50. {xax-0.1.1 → xax-0.1.3}/xax/task/script.py +0 -0
  51. {xax-0.1.1 → xax-0.1.3}/xax/task/task.py +0 -0
  52. {xax-0.1.1 → xax-0.1.3}/xax/utils/__init__.py +0 -0
  53. {xax-0.1.1 → xax-0.1.3}/xax/utils/data/__init__.py +0 -0
  54. {xax-0.1.1 → xax-0.1.3}/xax/utils/data/collate.py +0 -0
  55. {xax-0.1.1 → xax-0.1.3}/xax/utils/debugging.py +0 -0
  56. {xax-0.1.1 → xax-0.1.3}/xax/utils/jax.py +0 -0
  57. {xax-0.1.1 → xax-0.1.3}/xax/utils/jaxpr.py +0 -0
  58. {xax-0.1.1 → xax-0.1.3}/xax/utils/logging.py +0 -0
  59. {xax-0.1.1 → xax-0.1.3}/xax/utils/numpy.py +0 -0
  60. {xax-0.1.1 → xax-0.1.3}/xax/utils/profile.py +0 -0
  61. {xax-0.1.1 → xax-0.1.3}/xax/utils/pytree.py +0 -0
  62. {xax-0.1.1 → xax-0.1.3}/xax/utils/tensorboard.py +0 -0
  63. {xax-0.1.1 → xax-0.1.3}/xax/utils/text.py +0 -0
  64. {xax-0.1.1 → xax-0.1.3}/xax.egg-info/SOURCES.txt +0 -0
  65. {xax-0.1.1 → xax-0.1.3}/xax.egg-info/dependency_links.txt +0 -0
  66. {xax-0.1.1 → xax-0.1.3}/xax.egg-info/requires.txt +0 -0
  67. {xax-0.1.1 → xax-0.1.3}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.1
3
+ Version: 0.1.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.1"
15
+ __version__ = "0.1.3"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -66,6 +66,7 @@ __all__ = [
66
66
  "DataloaderConfig",
67
67
  "GPUStatsOptions",
68
68
  "StepContext",
69
+ "ValidStepTimer",
69
70
  "Script",
70
71
  "ScriptConfig",
71
72
  "Config",
@@ -219,6 +220,7 @@ NAME_MAP: dict[str, str] = {
219
220
  "DataloaderConfig": "task.mixins.data_loader",
220
221
  "GPUStatsOptions": "task.mixins.gpu_stats",
221
222
  "StepContext": "task.mixins.step_wrapper",
223
+ "ValidStepTimer": "task.mixins.train",
222
224
  "Script": "task.script",
223
225
  "ScriptConfig": "task.script",
224
226
  "Config": "task.task",
@@ -372,7 +374,7 @@ if IMPORT_ALL or TYPE_CHECKING:
372
374
  from xax.task.mixins.data_loader import DataloaderConfig
373
375
  from xax.task.mixins.gpu_stats import GPUStatsOptions
374
376
  from xax.task.mixins.step_wrapper import StepContext
375
- from xax.task.mixins.train import Batch, Output
377
+ from xax.task.mixins.train import Batch, Output, ValidStepTimer
376
378
  from xax.task.script import Script, ScriptConfig
377
379
  from xax.task.task import Config, Task
378
380
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
@@ -47,7 +47,6 @@ class CheckpointingConfig(ArtifactsConfig):
47
47
  only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
48
48
  load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
49
49
  load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
50
- save_tf_model: bool = field(False, help="If set, saves a Tensorflow version of the model")
51
50
 
52
51
 
53
52
  Config = TypeVar("Config", bound=CheckpointingConfig)
@@ -213,15 +212,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
213
212
  add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
214
213
  add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
215
214
 
216
- if self.config.save_tf_model:
217
- try:
218
- from jax.experimental import jax2tf
219
- except ModuleNotFoundError:
220
- raise ImportError("Tensorflow is not installed. Install it with `pip install tensorflow`")
221
-
222
- tf_model = jax2tf.convert(model)
223
- add_file("model.tf", lambda buf: cloudpickle.dump(tf_model, buf))
224
-
225
215
  # Updates the symlink to the new checkpoint.
226
216
  last_ckpt_path.unlink(missing_ok=True)
227
217
  try:
@@ -50,6 +50,7 @@ from xax.utils.experiments import (
50
50
  diff_configs,
51
51
  get_diff_string,
52
52
  get_git_state,
53
+ get_packages_with_versions,
53
54
  get_training_code,
54
55
  )
55
56
  from xax.utils.logging import LOG_STATUS
@@ -452,6 +453,7 @@ class TrainMixin(
452
453
  logger.log(LOG_STATUS, self.task_name)
453
454
  logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
454
455
  self.logger.log_file("git_state.txt", get_git_state(self))
456
+ self.logger.log_file("packages.txt", get_packages_with_versions())
455
457
  self.logger.log_file("training_code.txt", get_training_code(self))
456
458
  self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
457
459
 
@@ -28,6 +28,7 @@ from typing import Any, Iterator, Self, TypeVar, cast
28
28
  from urllib.parse import urlparse
29
29
 
30
30
  import git
31
+ import pkg_resources
31
32
  import requests
32
33
  from jaxtyping import Array
33
34
  from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf
@@ -109,8 +110,8 @@ class StateTimer:
109
110
 
110
111
  def step(self, state: State) -> None:
111
112
  cur_time = time.time()
112
- self.step_timer.step(state.num_steps, cur_time)
113
- self.sample_timer.step(state.num_samples, cur_time)
113
+ self.step_timer.step(state.num_steps if state.phase == "train" else state.num_valid_steps, cur_time)
114
+ self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
114
115
  self.iter_timer.step(cur_time)
115
116
 
116
117
  def log_dict(self) -> dict[str, dict[str, int | float]]:
@@ -468,6 +469,16 @@ def get_git_state(obj: object) -> str:
468
469
  return traceback.format_exc()
469
470
 
470
471
 
472
+ def get_packages_with_versions() -> str:
473
+ """Gets the packages and their versions.
474
+
475
+ Returns:
476
+ A dictionary of packages and their versions.
477
+ """
478
+ packages = [(pkg.key, pkg.version) for pkg in pkg_resources.working_set]
479
+ return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
480
+
481
+
471
482
  def get_training_code(obj: object) -> str:
472
483
  """Gets the text from the file containing the provided object.
473
484
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.1
3
+ Version: 0.1.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes