xax 0.2.17__py3-none-any.whl → 0.2.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/task/mixins/checkpointing.py +9 -2
- {xax-0.2.17.dist-info → xax-0.2.18.dist-info}/METADATA +1 -1
- {xax-0.2.17.dist-info → xax-0.2.18.dist-info}/RECORD +7 -7
- {xax-0.2.17.dist-info → xax-0.2.18.dist-info}/WHEEL +1 -1
- {xax-0.2.17.dist-info → xax-0.2.18.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.17.dist-info → xax-0.2.18.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
xax/task/mixins/checkpointing.py
CHANGED
@@ -6,7 +6,7 @@ import logging
|
|
6
6
|
import tarfile
|
7
7
|
from dataclasses import dataclass
|
8
8
|
from pathlib import Path
|
9
|
-
from typing import Generic, Literal, Sequence, TypeVar, cast, overload
|
9
|
+
from typing import Generic, Literal, Self, Sequence, TypeVar, cast, overload
|
10
10
|
|
11
11
|
import equinox as eqx
|
12
12
|
import jax
|
@@ -46,7 +46,6 @@ class CheckpointingConfig(ArtifactsConfig):
|
|
46
46
|
save_every_n_seconds: float | None = field(60.0 * 60.0, help="Save a checkpoint every N seconds")
|
47
47
|
only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
|
48
48
|
load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
|
49
|
-
load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
|
50
49
|
|
51
50
|
|
52
51
|
Config = TypeVar("Config", bound=CheckpointingConfig)
|
@@ -306,3 +305,11 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
306
305
|
self.on_after_checkpoint_save(ckpt_path, state)
|
307
306
|
|
308
307
|
return ckpt_path
|
308
|
+
|
309
|
+
@classmethod
|
310
|
+
def load_config(cls, ckpt_path: str | Path) -> Config:
|
311
|
+
return cls.get_config(load_ckpt(Path(ckpt_path), part="config"), use_cli=False)
|
312
|
+
|
313
|
+
@classmethod
|
314
|
+
def load_task(cls, ckpt_path: str | Path) -> Self:
|
315
|
+
return cls(cls.load_config(ckpt_path))
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=ru-amUXLIpFI1cE7NfkPpff6NnuLbSh6LHbJkmvH2zM,15733
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
@@ -32,7 +32,7 @@ xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,66
|
|
32
32
|
xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
|
33
33
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
34
|
xax/task/mixins/artifacts.py,sha256=Ma7fwsp-SA1w6GcuBSskszj5TB83yxYJm4Ns_EnqkI4,3018
|
35
|
-
xax/task/mixins/checkpointing.py,sha256=
|
35
|
+
xax/task/mixins/checkpointing.py,sha256=w-xsx8tH8PPYRf9eoqvz-4kAdunHMvazcoZAIj6niyI,11468
|
36
36
|
xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
|
37
37
|
xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
|
38
38
|
xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.2.
|
62
|
-
xax-0.2.
|
63
|
-
xax-0.2.
|
64
|
-
xax-0.2.
|
65
|
-
xax-0.2.
|
61
|
+
xax-0.2.18.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.2.18.dist-info/METADATA,sha256=phU79eE74vMM79HouJP9iQgTDONJ59no3ktUl5GHF78,1880
|
63
|
+
xax-0.2.18.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
|
64
|
+
xax-0.2.18.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.2.18.dist-info/RECORD,,
|
File without changes
|
File without changes
|