xax 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.1"
15
+ __version__ = "0.1.3"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -66,6 +66,7 @@ __all__ = [
66
66
  "DataloaderConfig",
67
67
  "GPUStatsOptions",
68
68
  "StepContext",
69
+ "ValidStepTimer",
69
70
  "Script",
70
71
  "ScriptConfig",
71
72
  "Config",
@@ -219,6 +220,7 @@ NAME_MAP: dict[str, str] = {
219
220
  "DataloaderConfig": "task.mixins.data_loader",
220
221
  "GPUStatsOptions": "task.mixins.gpu_stats",
221
222
  "StepContext": "task.mixins.step_wrapper",
223
+ "ValidStepTimer": "task.mixins.train",
222
224
  "Script": "task.script",
223
225
  "ScriptConfig": "task.script",
224
226
  "Config": "task.task",
@@ -372,7 +374,7 @@ if IMPORT_ALL or TYPE_CHECKING:
372
374
  from xax.task.mixins.data_loader import DataloaderConfig
373
375
  from xax.task.mixins.gpu_stats import GPUStatsOptions
374
376
  from xax.task.mixins.step_wrapper import StepContext
375
- from xax.task.mixins.train import Batch, Output
377
+ from xax.task.mixins.train import Batch, Output, ValidStepTimer
376
378
  from xax.task.script import Script, ScriptConfig
377
379
  from xax.task.task import Config, Task
378
380
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
@@ -47,7 +47,6 @@ class CheckpointingConfig(ArtifactsConfig):
47
47
  only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
48
48
  load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
49
49
  load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
50
- save_tf_model: bool = field(False, help="If set, saves a Tensorflow version of the model")
51
50
 
52
51
 
53
52
  Config = TypeVar("Config", bound=CheckpointingConfig)
@@ -213,15 +212,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
213
212
  add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
214
213
  add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
215
214
 
216
- if self.config.save_tf_model:
217
- try:
218
- from jax.experimental import jax2tf
219
- except ModuleNotFoundError:
220
- raise ImportError("Tensorflow is not installed. Install it with `pip install tensorflow`")
221
-
222
- tf_model = jax2tf.convert(model)
223
- add_file("model.tf", lambda buf: cloudpickle.dump(tf_model, buf))
224
-
225
215
  # Updates the symlink to the new checkpoint.
226
216
  last_ckpt_path.unlink(missing_ok=True)
227
217
  try:
xax/task/mixins/train.py CHANGED
@@ -50,6 +50,7 @@ from xax.utils.experiments import (
50
50
  diff_configs,
51
51
  get_diff_string,
52
52
  get_git_state,
53
+ get_packages_with_versions,
53
54
  get_training_code,
54
55
  )
55
56
  from xax.utils.logging import LOG_STATUS
@@ -452,6 +453,7 @@ class TrainMixin(
452
453
  logger.log(LOG_STATUS, self.task_name)
453
454
  logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
454
455
  self.logger.log_file("git_state.txt", get_git_state(self))
456
+ self.logger.log_file("packages.txt", get_packages_with_versions())
455
457
  self.logger.log_file("training_code.txt", get_training_code(self))
456
458
  self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
457
459
 
xax/utils/experiments.py CHANGED
@@ -28,6 +28,7 @@ from typing import Any, Iterator, Self, TypeVar, cast
28
28
  from urllib.parse import urlparse
29
29
 
30
30
  import git
31
+ import pkg_resources
31
32
  import requests
32
33
  from jaxtyping import Array
33
34
  from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf
@@ -109,8 +110,8 @@ class StateTimer:
109
110
 
110
111
  def step(self, state: State) -> None:
111
112
  cur_time = time.time()
112
- self.step_timer.step(state.num_steps, cur_time)
113
- self.sample_timer.step(state.num_samples, cur_time)
113
+ self.step_timer.step(state.num_steps if state.phase == "train" else state.num_valid_steps, cur_time)
114
+ self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
114
115
  self.iter_timer.step(cur_time)
115
116
 
116
117
  def log_dict(self) -> dict[str, dict[str, int | float]]:
@@ -468,6 +469,16 @@ def get_git_state(obj: object) -> str:
468
469
  return traceback.format_exc()
469
470
 
470
471
 
472
+ def get_packages_with_versions() -> str:
473
+ """Gets the packages and their versions.
474
+
475
+ Returns:
476
+ A dictionary of packages and their versions.
477
+ """
478
+ packages = [(pkg.key, pkg.version) for pkg in pkg_resources.working_set]
479
+ return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
480
+
481
+
471
482
  def get_training_code(obj: object) -> str:
472
483
  """Gets the text from the file containing the provided object.
473
484
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.1
3
+ Version: 0.1.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=JyKRACir9b0bkuG93bwxADFrVr-Lo76kenDBJtvb_wQ,13280
1
+ xax/__init__.py,sha256=n7EXl0pwEPzlw2DjS-3ePgx0VoQnMDnHLVc5exkHGcM,13361
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -30,7 +30,7 @@ xax/task/loggers/stdout.py,sha256=bR0k-PfmFgLfPxLPb4hZw_8G_msA32UeHfAAu11nEYs,67
30
30
  xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
31
31
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
32
32
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
33
- xax/task/mixins/checkpointing.py,sha256=sRkVxJbQfqDf1-lp1KFrAGYWHhTlV8_DORxGQ_69P1A,8954
33
+ xax/task/mixins/checkpointing.py,sha256=a6tVyISsDIz68rrhb1rAh3rjQlqkDVJCmSBmETQrnRM,8480
34
34
  xax/task/mixins/compile.py,sha256=FRsxwLnZjjxpeWJ7Bx_d8XUY50oDoGidgpeRt4ejeQk,3377
35
35
  xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
36
36
  xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
@@ -39,10 +39,10 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
39
39
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
40
40
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
41
41
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
42
- xax/task/mixins/train.py,sha256=BEC7HSwBlGZDe7jCsedqEA8-K1Zx52-bTjsBONYIE5g,22225
42
+ xax/task/mixins/train.py,sha256=8AaBXaopnrxtSZXldyFCE3QX1k5r3IsZMr6O0ICnNnU,22332
43
43
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
44
44
  xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
45
- xax/utils/experiments.py,sha256=_cwoBaiBxoQ_Tstm0rz7TEqfELqcktmPflb6AP1K0qA,28779
45
+ xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
46
46
  xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
47
47
  xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
48
48
  xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
@@ -53,8 +53,8 @@ xax/utils/tensorboard.py,sha256=_S70dS69pduiD05viHAGgYGsaBry1QL2ej6ZwUIXPOE,1617
53
53
  xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
54
54
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
55
55
  xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
56
- xax-0.1.1.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
57
- xax-0.1.1.dist-info/METADATA,sha256=tJ4ilL3uBbykHBQTHbh-bN6m4hrHqivyyFeuI33ddX4,1877
58
- xax-0.1.1.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
59
- xax-0.1.1.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
60
- xax-0.1.1.dist-info/RECORD,,
56
+ xax-0.1.3.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
57
+ xax-0.1.3.dist-info/METADATA,sha256=m3AyjlRD9C-O2Tp5zH5i5TbEL7bZooeIpypUCYuYPtQ,1877
58
+ xax-0.1.3.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
59
+ xax-0.1.3.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
60
+ xax-0.1.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (77.0.3)
2
+ Generator: setuptools (78.0.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5