xax 0.2.17__py3-none-any.whl → 0.2.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.17"
15
+ __version__ = "0.2.18"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -6,7 +6,7 @@ import logging
6
6
  import tarfile
7
7
  from dataclasses import dataclass
8
8
  from pathlib import Path
9
- from typing import Generic, Literal, Sequence, TypeVar, cast, overload
9
+ from typing import Generic, Literal, Self, Sequence, TypeVar, cast, overload
10
10
 
11
11
  import equinox as eqx
12
12
  import jax
@@ -46,7 +46,6 @@ class CheckpointingConfig(ArtifactsConfig):
46
46
  save_every_n_seconds: float | None = field(60.0 * 60.0, help="Save a checkpoint every N seconds")
47
47
  only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
48
48
  load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
49
- load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
50
49
 
51
50
 
52
51
  Config = TypeVar("Config", bound=CheckpointingConfig)
@@ -306,3 +305,11 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
306
305
  self.on_after_checkpoint_save(ckpt_path, state)
307
306
 
308
307
  return ckpt_path
308
+
309
+ @classmethod
310
+ def load_config(cls, ckpt_path: str | Path) -> Config:
311
+ return cls.get_config(load_ckpt(Path(ckpt_path), part="config"), use_cli=False)
312
+
313
+ @classmethod
314
+ def load_task(cls, ckpt_path: str | Path) -> Self:
315
+ return cls(cls.load_config(ckpt_path))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.17
3
+ Version: 0.2.18
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=dBxZ_r1ck3C9ZH9VRM38-ApVkyUe7CnI_0SF9k07KcI,15733
1
+ xax/__init__.py,sha256=ru-amUXLIpFI1cE7NfkPpff6NnuLbSh6LHbJkmvH2zM,15733
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -32,7 +32,7 @@ xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,66
32
32
  xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
33
33
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
34
34
  xax/task/mixins/artifacts.py,sha256=Ma7fwsp-SA1w6GcuBSskszj5TB83yxYJm4Ns_EnqkI4,3018
35
- xax/task/mixins/checkpointing.py,sha256=ypdXvC6oJlsUGm4PiTJWXrtTi9w0K9IpoO0-8gM1hZ4,11295
35
+ xax/task/mixins/checkpointing.py,sha256=w-xsx8tH8PPYRf9eoqvz-4kAdunHMvazcoZAIj6niyI,11468
36
36
  xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
37
37
  xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
38
38
  xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.2.17.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.2.17.dist-info/METADATA,sha256=GQhyzReeHSrZkYrpxeSXt19z2271zD49-S6fwN6cagU,1880
63
- xax-0.2.17.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
64
- xax-0.2.17.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.2.17.dist-info/RECORD,,
61
+ xax-0.2.18.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.2.18.dist-info/METADATA,sha256=phU79eE74vMM79HouJP9iQgTDONJ59no3ktUl5GHF78,1880
63
+ xax-0.2.18.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
64
+ xax-0.2.18.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.2.18.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.0.0)
2
+ Generator: setuptools (80.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5