xax 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/core/state.py +2 -2
- xax/nn/export.py +2 -2
- xax/task/logger.py +2 -1
- xax/task/mixins/train.py +8 -5
- {xax-0.2.4.dist-info → xax-0.2.6.dist-info}/METADATA +6 -7
- {xax-0.2.4.dist-info → xax-0.2.6.dist-info}/RECORD +10 -10
- {xax-0.2.4.dist-info → xax-0.2.6.dist-info}/WHEEL +0 -0
- {xax-0.2.4.dist-info → xax-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.4.dist-info → xax-0.2.6.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
xax/core/state.py
CHANGED
@@ -89,9 +89,9 @@ class State:
|
|
89
89
|
int32_arr = int32_arr.at[1].set(kwargs["num_valid_steps"])
|
90
90
|
|
91
91
|
if "phase" in kwargs:
|
92
|
-
int32_arr = int32_arr.at[
|
92
|
+
int32_arr = int32_arr.at[2].set(_phase_to_int(kwargs["phase"]))
|
93
93
|
if "_phase" in kwargs:
|
94
|
-
int32_arr = int32_arr.at[
|
94
|
+
int32_arr = int32_arr.at[2].set(kwargs["_phase"])
|
95
95
|
|
96
96
|
if "num_samples" in kwargs:
|
97
97
|
float32_arr = float32_arr.at[0].set(kwargs["num_samples"])
|
xax/nn/export.py
CHANGED
@@ -14,8 +14,8 @@ try:
|
|
14
14
|
from orbax.export import ExportManager, JaxModule, ServingConfig
|
15
15
|
except ImportError as e:
|
16
16
|
raise ImportError(
|
17
|
-
"In order to export models, please install Xax with
|
18
|
-
"using 'xax[
|
17
|
+
"In order to export models, please install Xax with exportable dependencies, "
|
18
|
+
"using 'xax[exportable]` to install the required dependencies."
|
19
19
|
) from e
|
20
20
|
|
21
21
|
logger = logging.getLogger(__name__)
|
xax/task/logger.py
CHANGED
@@ -521,7 +521,8 @@ class LoggerImpl(ABC):
|
|
521
521
|
Returns:
|
522
522
|
If the logger should log the current step.
|
523
523
|
"""
|
524
|
-
|
524
|
+
elapsed_time = state.elapsed_time_s.item() if state.phase == "train" else state.valid_elapsed_time_s.item()
|
525
|
+
return self.tickers[state.phase].tick(elapsed_time)
|
525
526
|
|
526
527
|
|
527
528
|
class ToastHandler(logging.Handler):
|
xax/task/mixins/train.py
CHANGED
@@ -115,19 +115,22 @@ class ValidStepTimer:
|
|
115
115
|
self.last_valid_time: float | None = None
|
116
116
|
self.last_valid_step: int | None = None
|
117
117
|
|
118
|
+
def _reset(self, state: State) -> None:
|
119
|
+
self.last_valid_time = state.elapsed_time_s.item()
|
120
|
+
self.last_valid_step = state.num_steps.item()
|
121
|
+
|
118
122
|
def is_valid_step(self, state: State) -> bool:
|
119
123
|
if state.num_steps < self.valid_first_n_steps:
|
120
124
|
return True
|
121
125
|
|
122
126
|
if self.last_valid_time is None or self.last_valid_step is None:
|
123
|
-
self.
|
124
|
-
self.last_valid_step = state.num_steps.item()
|
127
|
+
self._reset(state)
|
125
128
|
return False
|
126
129
|
|
127
130
|
# Step-based validation.
|
128
131
|
valid_every_n_steps = self.valid_every_n_steps
|
129
132
|
if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
|
130
|
-
self.
|
133
|
+
self._reset(state)
|
131
134
|
return True
|
132
135
|
|
133
136
|
# Time-based validation.
|
@@ -136,14 +139,14 @@ class ValidStepTimer:
|
|
136
139
|
valid_every_n_seconds is not None
|
137
140
|
and state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
|
138
141
|
):
|
139
|
-
self.
|
142
|
+
self._reset(state)
|
140
143
|
return True
|
141
144
|
|
142
145
|
# Time-based validation for first validation step.
|
143
146
|
if self.first_valid_step_flag:
|
144
147
|
valid_first_n_seconds = self.valid_first_n_seconds
|
145
148
|
if valid_first_n_seconds is not None and state.elapsed_time_s.item() >= valid_first_n_seconds:
|
146
|
-
self.
|
149
|
+
self._reset(state)
|
147
150
|
self.first_valid_step_flag = False
|
148
151
|
return True
|
149
152
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: xax
|
3
|
-
Version: 0.2.4
|
3
|
+
Version: 0.2.6
|
4
4
|
Summary: A library for fast Jax experimentation
|
5
5
|
Home-page: https://github.com/kscalelabs/xax
|
6
6
|
Author: Benjamin Bolte
|
@@ -31,11 +31,10 @@ Requires-Dist: pytest; extra == "dev"
|
|
31
31
|
Requires-Dist: types-pillow; extra == "dev"
|
32
32
|
Requires-Dist: types-psutil; extra == "dev"
|
33
33
|
Requires-Dist: types-requests; extra == "dev"
|
34
|
-
Provides-Extra:
|
35
|
-
Requires-Dist:
|
36
|
-
Requires-Dist:
|
37
|
-
|
38
|
-
Requires-Dist: flax; extra == "flax"
|
34
|
+
Provides-Extra: exportable
|
35
|
+
Requires-Dist: flax; extra == "exportable"
|
36
|
+
Requires-Dist: orbax-export; extra == "exportable"
|
37
|
+
Requires-Dist: tensorflow; extra == "exportable"
|
39
38
|
Provides-Extra: all
|
40
39
|
Requires-Dist: black; extra == "all"
|
41
40
|
Requires-Dist: darglint; extra == "all"
|
@@ -45,9 +44,9 @@ Requires-Dist: pytest; extra == "all"
|
|
45
44
|
Requires-Dist: types-pillow; extra == "all"
|
46
45
|
Requires-Dist: types-psutil; extra == "all"
|
47
46
|
Requires-Dist: types-requests; extra == "all"
|
47
|
+
Requires-Dist: flax; extra == "all"
|
48
48
|
Requires-Dist: orbax-export; extra == "all"
|
49
49
|
Requires-Dist: tensorflow; extra == "all"
|
50
|
-
Requires-Dist: flax; extra == "all"
|
51
50
|
Dynamic: author
|
52
51
|
Dynamic: description
|
53
52
|
Dynamic: description-content-type
|
@@ -1,14 +1,14 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=pvOmycOUli9h53XMTlinjJdY4zzw_fvWoF05UQwoItI,14225
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
5
5
|
xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
6
|
xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
|
7
|
-
xax/core/state.py,sha256=
|
7
|
+
xax/core/state.py,sha256=yO25lMoLCUTJlHyLzQxlDbsHC_GZ3HkrKAq5huA7AkU,4552
|
8
8
|
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
9
|
xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
10
10
|
xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
|
11
|
-
xax/nn/export.py,sha256=
|
11
|
+
xax/nn/export.py,sha256=pRfM2B4hB2EvljysC6AjtgB_7Cn7JtaP3dhYU2stZtY,5545
|
12
12
|
xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
|
13
13
|
xax/nn/geom.py,sha256=rImNlkHWeoNcY7f84nknizJ6uzsrMhbAtKeb2xAWxNY,6215
|
14
14
|
xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
|
@@ -17,7 +17,7 @@ xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
|
17
17
|
xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
|
18
18
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
19
|
xax/task/base.py,sha256=OnXi2hiKPGwt6ng1dutnoQSiw7lEiWFlC_vx99_JsbQ,7694
|
20
|
-
xax/task/logger.py,sha256=
|
20
|
+
xax/task/logger.py,sha256=gE67AaPCfU_1FpxY3t0yNRrIVgtp8Sax9UyOqFYMtzM,40976
|
21
21
|
xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
|
22
22
|
xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
|
23
23
|
xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
|
|
41
41
|
xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
|
42
42
|
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
|
-
xax/task/mixins/train.py,sha256=
|
44
|
+
xax/task/mixins/train.py,sha256=X3DIMPqpMRAJvFeeDaw5qFkH70R6yJelBAD8KHHwkUc,31196
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
46
|
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
47
47
|
xax/utils/experiments.py,sha256=d2H63ECtVOKySMUMrQRqq4kcuZpoXqo-L931usDVAhE,29903
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.2.
|
62
|
-
xax-0.2.
|
63
|
-
xax-0.2.
|
64
|
-
xax-0.2.
|
65
|
-
xax-0.2.
|
61
|
+
xax-0.2.6.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.2.6.dist-info/METADATA,sha256=g--2dJ3PIYDOWURjNhVmkv5B_VUcI2xY0iBa4oGuuVc,1879
|
63
|
+
xax-0.2.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
64
|
+
xax-0.2.6.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.2.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|