xax 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +4 -2
- xax/utils/experiments.py +2 -2
- {xax-0.1.1.dist-info → xax-0.1.2.dist-info}/METADATA +1 -1
- {xax-0.1.1.dist-info → xax-0.1.2.dist-info}/RECORD +7 -7
- {xax-0.1.1.dist-info → xax-0.1.2.dist-info}/WHEEL +1 -1
- {xax-0.1.1.dist-info → xax-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.1.dist-info → xax-0.1.2.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.1.
|
15
|
+
__version__ = "0.1.2"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -66,6 +66,7 @@ __all__ = [
|
|
66
66
|
"DataloaderConfig",
|
67
67
|
"GPUStatsOptions",
|
68
68
|
"StepContext",
|
69
|
+
"ValidStepTimer",
|
69
70
|
"Script",
|
70
71
|
"ScriptConfig",
|
71
72
|
"Config",
|
@@ -219,6 +220,7 @@ NAME_MAP: dict[str, str] = {
|
|
219
220
|
"DataloaderConfig": "task.mixins.data_loader",
|
220
221
|
"GPUStatsOptions": "task.mixins.gpu_stats",
|
221
222
|
"StepContext": "task.mixins.step_wrapper",
|
223
|
+
"ValidStepTimer": "task.mixins.train",
|
222
224
|
"Script": "task.script",
|
223
225
|
"ScriptConfig": "task.script",
|
224
226
|
"Config": "task.task",
|
@@ -372,7 +374,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
372
374
|
from xax.task.mixins.data_loader import DataloaderConfig
|
373
375
|
from xax.task.mixins.gpu_stats import GPUStatsOptions
|
374
376
|
from xax.task.mixins.step_wrapper import StepContext
|
375
|
-
from xax.task.mixins.train import Batch, Output
|
377
|
+
from xax.task.mixins.train import Batch, Output, ValidStepTimer
|
376
378
|
from xax.task.script import Script, ScriptConfig
|
377
379
|
from xax.task.task import Config, Task
|
378
380
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
xax/utils/experiments.py
CHANGED
@@ -109,8 +109,8 @@ class StateTimer:
|
|
109
109
|
|
110
110
|
def step(self, state: State) -> None:
|
111
111
|
cur_time = time.time()
|
112
|
-
self.step_timer.step(state.num_steps, cur_time)
|
113
|
-
self.sample_timer.step(state.num_samples, cur_time)
|
112
|
+
self.step_timer.step(state.num_steps if state.phase == "train" else state.num_valid_steps, cur_time)
|
113
|
+
self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
|
114
114
|
self.iter_timer.step(cur_time)
|
115
115
|
|
116
116
|
def log_dict(self) -> dict[str, dict[str, int | float]]:
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=Ti6hrfoY5wnywzOvkvtCwq2SvLsjYfbm_6U_UzYakls,13361
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
|
@@ -42,7 +42,7 @@ xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2
|
|
42
42
|
xax/task/mixins/train.py,sha256=BEC7HSwBlGZDe7jCsedqEA8-K1Zx52-bTjsBONYIE5g,22225
|
43
43
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
44
44
|
xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
|
45
|
-
xax/utils/experiments.py,sha256=
|
45
|
+
xax/utils/experiments.py,sha256=d-e-RCw9PlnuqV3FPW0U74zcvlOKV48lUrX8tvAfhew,28887
|
46
46
|
xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
|
47
47
|
xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
|
48
48
|
xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
|
@@ -53,8 +53,8 @@ xax/utils/tensorboard.py,sha256=_S70dS69pduiD05viHAGgYGsaBry1QL2ej6ZwUIXPOE,1617
|
|
53
53
|
xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
|
54
54
|
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
55
55
|
xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
|
56
|
-
xax-0.1.
|
57
|
-
xax-0.1.
|
58
|
-
xax-0.1.
|
59
|
-
xax-0.1.
|
60
|
-
xax-0.1.
|
56
|
+
xax-0.1.2.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
57
|
+
xax-0.1.2.dist-info/METADATA,sha256=-BB6_Qiip_pPkf96Wl9FZsM_7PPKJr5l8v2owrXCvoI,1877
|
58
|
+
xax-0.1.2.dist-info/WHEEL,sha256=L0N565qmK-3nM2eBoMNFszYJ_MTx03_tQ0CQu1bHLYo,91
|
59
|
+
xax-0.1.2.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
60
|
+
xax-0.1.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|