xax 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +4 -2
- xax/task/mixins/checkpointing.py +0 -10
- xax/task/mixins/train.py +2 -0
- xax/utils/experiments.py +13 -2
- {xax-0.1.1.dist-info → xax-0.1.3.dist-info}/METADATA +1 -1
- {xax-0.1.1.dist-info → xax-0.1.3.dist-info}/RECORD +9 -9
- {xax-0.1.1.dist-info → xax-0.1.3.dist-info}/WHEEL +1 -1
- {xax-0.1.1.dist-info → xax-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.1.dist-info → xax-0.1.3.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.1.1"
|
15
|
+
__version__ = "0.1.3"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -66,6 +66,7 @@ __all__ = [
|
|
66
66
|
"DataloaderConfig",
|
67
67
|
"GPUStatsOptions",
|
68
68
|
"StepContext",
|
69
|
+
"ValidStepTimer",
|
69
70
|
"Script",
|
70
71
|
"ScriptConfig",
|
71
72
|
"Config",
|
@@ -219,6 +220,7 @@ NAME_MAP: dict[str, str] = {
|
|
219
220
|
"DataloaderConfig": "task.mixins.data_loader",
|
220
221
|
"GPUStatsOptions": "task.mixins.gpu_stats",
|
221
222
|
"StepContext": "task.mixins.step_wrapper",
|
223
|
+
"ValidStepTimer": "task.mixins.train",
|
222
224
|
"Script": "task.script",
|
223
225
|
"ScriptConfig": "task.script",
|
224
226
|
"Config": "task.task",
|
@@ -372,7 +374,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
372
374
|
from xax.task.mixins.data_loader import DataloaderConfig
|
373
375
|
from xax.task.mixins.gpu_stats import GPUStatsOptions
|
374
376
|
from xax.task.mixins.step_wrapper import StepContext
|
375
|
-
from xax.task.mixins.train import Batch, Output
|
377
|
+
from xax.task.mixins.train import Batch, Output, ValidStepTimer
|
376
378
|
from xax.task.script import Script, ScriptConfig
|
377
379
|
from xax.task.task import Config, Task
|
378
380
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
xax/task/mixins/checkpointing.py
CHANGED
@@ -47,7 +47,6 @@ class CheckpointingConfig(ArtifactsConfig):
|
|
47
47
|
only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
|
48
48
|
load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
|
49
49
|
load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
|
50
|
-
save_tf_model: bool = field(False, help="If set, saves a Tensorflow version of the model")
|
51
50
|
|
52
51
|
|
53
52
|
Config = TypeVar("Config", bound=CheckpointingConfig)
|
@@ -213,15 +212,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
213
212
|
add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
|
214
213
|
add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
|
215
214
|
|
216
|
-
if self.config.save_tf_model:
|
217
|
-
try:
|
218
|
-
from jax.experimental import jax2tf
|
219
|
-
except ModuleNotFoundError:
|
220
|
-
raise ImportError("Tensorflow is not installed. Install it with `pip install tensorflow`")
|
221
|
-
|
222
|
-
tf_model = jax2tf.convert(model)
|
223
|
-
add_file("model.tf", lambda buf: cloudpickle.dump(tf_model, buf))
|
224
|
-
|
225
215
|
# Updates the symlink to the new checkpoint.
|
226
216
|
last_ckpt_path.unlink(missing_ok=True)
|
227
217
|
try:
|
xax/task/mixins/train.py
CHANGED
@@ -50,6 +50,7 @@ from xax.utils.experiments import (
|
|
50
50
|
diff_configs,
|
51
51
|
get_diff_string,
|
52
52
|
get_git_state,
|
53
|
+
get_packages_with_versions,
|
53
54
|
get_training_code,
|
54
55
|
)
|
55
56
|
from xax.utils.logging import LOG_STATUS
|
@@ -452,6 +453,7 @@ class TrainMixin(
|
|
452
453
|
logger.log(LOG_STATUS, self.task_name)
|
453
454
|
logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
|
454
455
|
self.logger.log_file("git_state.txt", get_git_state(self))
|
456
|
+
self.logger.log_file("packages.txt", get_packages_with_versions())
|
455
457
|
self.logger.log_file("training_code.txt", get_training_code(self))
|
456
458
|
self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
|
457
459
|
|
xax/utils/experiments.py
CHANGED
@@ -28,6 +28,7 @@ from typing import Any, Iterator, Self, TypeVar, cast
|
|
28
28
|
from urllib.parse import urlparse
|
29
29
|
|
30
30
|
import git
|
31
|
+
import pkg_resources
|
31
32
|
import requests
|
32
33
|
from jaxtyping import Array
|
33
34
|
from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf
|
@@ -109,8 +110,8 @@ class StateTimer:
|
|
109
110
|
|
110
111
|
def step(self, state: State) -> None:
|
111
112
|
cur_time = time.time()
|
112
|
-
self.step_timer.step(state.num_steps, cur_time)
|
113
|
-
self.sample_timer.step(state.num_samples, cur_time)
|
113
|
+
self.step_timer.step(state.num_steps if state.phase == "train" else state.num_valid_steps, cur_time)
|
114
|
+
self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
|
114
115
|
self.iter_timer.step(cur_time)
|
115
116
|
|
116
117
|
def log_dict(self) -> dict[str, dict[str, int | float]]:
|
@@ -468,6 +469,16 @@ def get_git_state(obj: object) -> str:
|
|
468
469
|
return traceback.format_exc()
|
469
470
|
|
470
471
|
|
472
|
+
def get_packages_with_versions() -> str:
|
473
|
+
"""Gets the packages and their versions.
|
474
|
+
|
475
|
+
Returns:
|
476
|
+
A newline-separated string with one sorted ``key==version`` entry per package in the working set.
|
477
|
+
"""
|
478
|
+
packages = [(pkg.key, pkg.version) for pkg in pkg_resources.working_set]
|
479
|
+
return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
|
480
|
+
|
481
|
+
|
471
482
|
def get_training_code(obj: object) -> str:
|
472
483
|
"""Gets the text from the file containing the provided object.
|
473
484
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=n7EXl0pwEPzlw2DjS-3ePgx0VoQnMDnHLVc5exkHGcM,13361
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
|
@@ -30,7 +30,7 @@ xax/task/loggers/stdout.py,sha256=bR0k-PfmFgLfPxLPb4hZw_8G_msA32UeHfAAu11nEYs,67
|
|
30
30
|
xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
|
31
31
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
32
32
|
xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
|
33
|
-
xax/task/mixins/checkpointing.py,sha256=
|
33
|
+
xax/task/mixins/checkpointing.py,sha256=a6tVyISsDIz68rrhb1rAh3rjQlqkDVJCmSBmETQrnRM,8480
|
34
34
|
xax/task/mixins/compile.py,sha256=FRsxwLnZjjxpeWJ7Bx_d8XUY50oDoGidgpeRt4ejeQk,3377
|
35
35
|
xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
|
36
36
|
xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
|
@@ -39,10 +39,10 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
|
|
39
39
|
xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
|
40
40
|
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
41
41
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
42
|
-
xax/task/mixins/train.py,sha256=
|
42
|
+
xax/task/mixins/train.py,sha256=8AaBXaopnrxtSZXldyFCE3QX1k5r3IsZMr6O0ICnNnU,22332
|
43
43
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
44
44
|
xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
|
45
|
-
xax/utils/experiments.py,sha256=
|
45
|
+
xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
|
46
46
|
xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
|
47
47
|
xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
|
48
48
|
xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
|
@@ -53,8 +53,8 @@ xax/utils/tensorboard.py,sha256=_S70dS69pduiD05viHAGgYGsaBry1QL2ej6ZwUIXPOE,1617
|
|
53
53
|
xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
|
54
54
|
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
55
55
|
xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
|
56
|
-
xax-0.1.
|
57
|
-
xax-0.1.
|
58
|
-
xax-0.1.
|
59
|
-
xax-0.1.
|
60
|
-
xax-0.1.
|
56
|
+
xax-0.1.3.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
57
|
+
xax-0.1.3.dist-info/METADATA,sha256=m3AyjlRD9C-O2Tp5zH5i5TbEL7bZooeIpypUCYuYPtQ,1877
|
58
|
+
xax-0.1.3.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
|
59
|
+
xax-0.1.3.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
60
|
+
xax-0.1.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|