xax 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +74 -2
- xax/core/conf.py +8 -33
- xax/core/state.py +13 -23
- xax/requirements.txt +2 -0
- xax/task/base.py +2 -0
- xax/task/logger.py +194 -122
- xax/task/loggers/callback.py +4 -16
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +14 -28
- xax/task/mixins/__init__.py +1 -0
- xax/task/mixins/artifacts.py +7 -4
- xax/task/mixins/checkpointing.py +12 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +16 -5
- xax/task/mixins/data_loader.py +23 -12
- xax/task/mixins/gpu_stats.py +19 -5
- xax/task/mixins/logger.py +4 -2
- xax/task/mixins/process.py +4 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +189 -129
- xax/task/script.py +1 -1
- xax/task/task.py +7 -0
- xax/utils/tensorboard.py +48 -0
- {xax-0.0.5.dist-info → xax-0.0.6.dist-info}/METADATA +12 -2
- xax-0.0.6.dist-info/RECORD +52 -0
- {xax-0.0.5.dist-info → xax-0.0.6.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.5.dist-info/RECORD +0 -52
- {xax-0.0.5.dist-info → xax-0.0.6.dist-info}/LICENSE +0 -0
- {xax-0.0.5.dist-info → xax-0.0.6.dist-info}/top_level.txt +0 -0
xax/task/launchers/staged.py
DELETED
@@ -1,29 +0,0 @@
|
|
1
|
-
"""Defines a base class with utility functions for staged training runs."""
|
2
|
-
|
3
|
-
from abc import ABC
|
4
|
-
from pathlib import Path
|
5
|
-
|
6
|
-
from xax.task.launchers.base import BaseLauncher
|
7
|
-
from xax.task.mixins.artifacts import ArtifactsMixin, Config
|
8
|
-
|
9
|
-
|
10
|
-
class StagedLauncher(BaseLauncher, ABC):
|
11
|
-
def __init__(self, config_file_name: str = "config.yaml") -> None:
|
12
|
-
super().__init__()
|
13
|
-
|
14
|
-
self.config_file_name = config_file_name
|
15
|
-
|
16
|
-
def get_config_path(self, task: "ArtifactsMixin[Config]", use_cli: bool | list[str] = True) -> Path:
|
17
|
-
config_path = task.exp_dir / self.config_file_name
|
18
|
-
task.config.exp_dir = str(task.exp_dir)
|
19
|
-
with open(config_path, "w", encoding="utf-8") as f:
|
20
|
-
f.write(task.config_str(task.config, use_cli=use_cli))
|
21
|
-
return config_path
|
22
|
-
|
23
|
-
@classmethod
|
24
|
-
def from_components(cls, task_key: str, config_path: Path, use_cli: bool | list[str] = True) -> "ArtifactsMixin":
|
25
|
-
return (
|
26
|
-
ArtifactsMixin.from_task_key(task_key)
|
27
|
-
.get_task(config_path, use_cli=use_cli)
|
28
|
-
.set_exp_dir(config_path.parent)
|
29
|
-
)
|
xax-0.0.5.dist-info/RECORD
DELETED
@@ -1,52 +0,0 @@
|
|
1
|
-
xax/__init__.py,sha256=3OQTnHGYgaux3i9gTYZxfK8F2zS_hK2QqD-G-Z1TfHQ,7623
|
2
|
-
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
|
-
xax/requirements.txt,sha256=DRn2B9d3mAr57-U3IOIrKm2nYz8H3cYgDy6EIC3SsuE,266
|
5
|
-
xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
-
xax/core/conf.py,sha256=hwgc5sJw0YRSegQLLrmIDtscev-H_a2ST1-V6BJ5aec,5915
|
7
|
-
xax/core/state.py,sha256=7lnVSytuhwPfcobPGdjfQ0QxbLgzWQNipKwXchd58QI,2695
|
8
|
-
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
-
xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
10
|
-
xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
|
11
|
-
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
12
|
-
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
|
-
xax/task/base.py,sha256=n82Sw-kMLr-WZzh0c_vAAQ2b-DHRYs0U8biPRonBxKU,7252
|
14
|
-
xax/task/logger.py,sha256=MAFIgd6yO0pD3gJHfKTwUDcwaM8DZD3AZtFLvrQtlFo,26740
|
15
|
-
xax/task/script.py,sha256=oBGnScYa_X284fCajabPCcbaSEIqR8nO4d40dvMv3NQ,1011
|
16
|
-
xax/task/task.py,sha256=X7TV_gt6C4m_-Il22Uyr5iMm-eh15oH5v1dl96sv1go,1295
|
17
|
-
xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
18
|
-
xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
|
19
|
-
xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
|
20
|
-
xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
|
21
|
-
xax/task/launchers/staged.py,sha256=jYeT9u58CN4ldV-ltJiQXQglEWOnEckHWnHYjfJQaoY,1102
|
22
|
-
xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
23
|
-
xax/task/loggers/callback.py,sha256=reaRuJs5iB6WWNgh3_tsuz_QPAlBC-5Ed2wCG_6Wj4M,2075
|
24
|
-
xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
|
25
|
-
xax/task/loggers/state.py,sha256=qyb-q8MdagN7BX-DhKucwoc45tIZJrPuvVDVoysTKC4,1576
|
26
|
-
xax/task/loggers/stdout.py,sha256=nxQXkS9JUR38RKsU9qj0dgePKguK0BFa9nl_BdGO8cE,6758
|
27
|
-
xax/task/loggers/tensorboard.py,sha256=DMYRDCQ9c-xHqO4kkZvc1-53PXCf2gX0aRiiAQDtHJ0,7293
|
28
|
-
xax/task/mixins/__init__.py,sha256=NkSAjMN5jpXE6LROIwMzX60z7UsTBpGs624_mNUWquo,745
|
29
|
-
xax/task/mixins/artifacts.py,sha256=G0984WuXII_R13IlJZn9En7iM83ISXKjeVYvn7j4wBs,3754
|
30
|
-
xax/task/mixins/checkpointing.py,sha256=JV91b5xyBUyZIbR3S-5UkBZNoAZYCnWx7Y-ayuU0lHQ,7989
|
31
|
-
xax/task/mixins/cpu_stats.py,sha256=Lqskt1t4usE6UslhANjwB0ZKOYmaC4dm9dnVKa6ERdA,8924
|
32
|
-
xax/task/mixins/data_loader.py,sha256=BPs0sYdctesnhS9nQ1rvT77MzLXznw5E4tAzWT1PpJY,5998
|
33
|
-
xax/task/mixins/gpu_stats.py,sha256=tFTNmtl9iMiLiYJSPg7gHR-ZxOP4P_ynzSmYNIAUoRw,8431
|
34
|
-
xax/task/mixins/logger.py,sha256=6XkjP_YUGY2CiDry0kDm1f9jqzJaLa1bPVYYnGjvSBU,2049
|
35
|
-
xax/task/mixins/process.py,sha256=HQAvEruvvfcS_IThrM4hKhFHZCAN2kFY_vEaZGLeZS8,1428
|
36
|
-
xax/task/mixins/runnable.py,sha256=d5-qyIpmNPtbTzE7qFJGGCPSREEDhX1VApUJPNDWye0,1933
|
37
|
-
xax/task/mixins/step_wrapper.py,sha256=Do4eGgZVuqDX9ZGDxQdfn6pRbUnHjQBAkTF0vnNH31E,1472
|
38
|
-
xax/task/mixins/train.py,sha256=Xeb0N9j-Znz5QnMDCXDGPqUSKMNLJkd8oF8giN45l2U,20099
|
39
|
-
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
40
|
-
xax/utils/experiments.py,sha256=qT3H0fyVH8DN417x7T0Xmz4SKoogW81-EHcZfyktFI8,28300
|
41
|
-
xax/utils/jax.py,sha256=VzEVB766UyH3_cgN6UP0FkCsDuGlYg5KJj8YJS4yYUk,439
|
42
|
-
xax/utils/logging.py,sha256=ST1hp2C2xntVVJBUHwo3YxPK19fBLNvHU2WGO1xqcXA,6418
|
43
|
-
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
44
|
-
xax/utils/tensorboard.py,sha256=XqxUlryFVsb75jE36uLcuoUhSr3nWg_-dzji2h6U_rI,8245
|
45
|
-
xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
|
46
|
-
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
47
|
-
xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
|
48
|
-
xax-0.0.5.dist-info/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
49
|
-
xax-0.0.5.dist-info/METADATA,sha256=VCiQmbjwZtiuORVyB0dloFTgLWtnK4o3FaolNWvf-A4,937
|
50
|
-
xax-0.0.5.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
51
|
-
xax-0.0.5.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
52
|
-
xax-0.0.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|