xax 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/core/state.py +21 -11
- xax/task/loggers/tensorboard.py +1 -1
- xax/task/mixins/train.py +29 -15
- {xax-0.2.11.dist-info → xax-0.2.13.dist-info}/METADATA +1 -1
- {xax-0.2.11.dist-info → xax-0.2.13.dist-info}/RECORD +9 -9
- {xax-0.2.11.dist-info → xax-0.2.13.dist-info}/WHEEL +0 -0
- {xax-0.2.11.dist-info → xax-0.2.13.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.11.dist-info → xax-0.2.13.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
xax/core/state.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
"""Defines a dataclass for keeping track of the current training state."""
|
2
2
|
|
3
|
+
import time
|
3
4
|
from dataclasses import dataclass
|
4
5
|
from typing import Literal, NotRequired, TypedDict, Unpack, cast
|
5
6
|
|
@@ -28,6 +29,7 @@ class StateDict(TypedDict, total=False):
|
|
28
29
|
num_samples: NotRequired[int | Array]
|
29
30
|
num_valid_steps: NotRequired[int | Array]
|
30
31
|
num_valid_samples: NotRequired[int | Array]
|
32
|
+
start_time_s: NotRequired[float | Array]
|
31
33
|
elapsed_time_s: NotRequired[float | Array]
|
32
34
|
valid_elapsed_time_s: NotRequired[float | Array]
|
33
35
|
phase: NotRequired[Phase]
|
@@ -57,13 +59,17 @@ class State:
|
|
57
59
|
return self._float32_arr[1]
|
58
60
|
|
59
61
|
@property
|
60
|
-
def
|
62
|
+
def start_time_s(self) -> Array:
|
61
63
|
return self._float32_arr[2]
|
62
64
|
|
63
65
|
@property
|
64
|
-
def
|
66
|
+
def elapsed_time_s(self) -> Array:
|
65
67
|
return self._float32_arr[3]
|
66
68
|
|
69
|
+
@property
|
70
|
+
def valid_elapsed_time_s(self) -> Array:
|
71
|
+
return self._float32_arr[4]
|
72
|
+
|
67
73
|
@property
|
68
74
|
def phase(self) -> Phase:
|
69
75
|
return _int_to_phase(self._int32_arr[2].item())
|
@@ -72,7 +78,7 @@ class State:
|
|
72
78
|
def init_state(cls) -> "State":
|
73
79
|
return cls(
|
74
80
|
_int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
|
75
|
-
_float32_arr=jnp.array([0.0, 0.0, 0.0, 0.0], dtype=jnp.float32),
|
81
|
+
_float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0, 0.0], dtype=jnp.float32),
|
76
82
|
)
|
77
83
|
|
78
84
|
@property
|
@@ -98,10 +104,12 @@ class State:
|
|
98
104
|
if "num_valid_samples" in kwargs:
|
99
105
|
float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
|
100
106
|
|
107
|
+
if "start_time_s" in kwargs:
|
108
|
+
float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
|
101
109
|
if "elapsed_time_s" in kwargs:
|
102
|
-
float32_arr = float32_arr.at[
|
110
|
+
float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
|
103
111
|
if "valid_elapsed_time_s" in kwargs:
|
104
|
-
float32_arr = float32_arr.at[
|
112
|
+
float32_arr = float32_arr.at[4].set(kwargs["valid_elapsed_time_s"])
|
105
113
|
|
106
114
|
return State(
|
107
115
|
_int32_arr=int32_arr,
|
@@ -110,12 +118,13 @@ class State:
|
|
110
118
|
|
111
119
|
def to_dict(self) -> dict[str, int | float | str]:
|
112
120
|
return {
|
113
|
-
"num_steps": int(self.num_steps),
|
114
|
-
"num_valid_steps": int(self.num_valid_steps),
|
115
|
-
"num_samples": int(self.num_samples),
|
116
|
-
"num_valid_samples": int(self.num_valid_samples),
|
117
|
-
"
|
118
|
-
"
|
121
|
+
"num_steps": int(self.num_steps.item()),
|
122
|
+
"num_valid_steps": int(self.num_valid_steps.item()),
|
123
|
+
"num_samples": int(self.num_samples.item()),
|
124
|
+
"num_valid_samples": int(self.num_valid_samples.item()),
|
125
|
+
"start_time_s": float(self.start_time_s.item()),
|
126
|
+
"elapsed_time_s": float(self.elapsed_time_s.item()),
|
127
|
+
"valid_elapsed_time_s": float(self.valid_elapsed_time_s.item()),
|
119
128
|
"phase": str(self.phase),
|
120
129
|
}
|
121
130
|
|
@@ -137,6 +146,7 @@ class State:
|
|
137
146
|
[
|
138
147
|
d.get("num_samples", 0),
|
139
148
|
d.get("num_valid_samples", 0),
|
149
|
+
d.get("start_time_s", time.time()),
|
140
150
|
d.get("elapsed_time_s", 0.0),
|
141
151
|
d.get("valid_elapsed_time_s", 0.0),
|
142
152
|
],
|
xax/task/loggers/tensorboard.py
CHANGED
@@ -160,7 +160,7 @@ class TensorboardLogger(LoggerImpl):
|
|
160
160
|
writer = self.get_writer(line.state.phase)
|
161
161
|
|
162
162
|
global_step = line.state.num_steps.item()
|
163
|
-
walltime = line.state.elapsed_time_s.item()
|
163
|
+
walltime = line.state.start_time_s.item() + line.state.elapsed_time_s.item()
|
164
164
|
|
165
165
|
for namespace, scalars in line.scalars.items():
|
166
166
|
for scalar_key, scalar_value in scalars.items():
|
xax/task/mixins/train.py
CHANGED
@@ -120,7 +120,7 @@ class ValidStepTimer:
|
|
120
120
|
self.last_valid_time = state.elapsed_time_s.item()
|
121
121
|
self.last_valid_step = state.num_steps.item()
|
122
122
|
|
123
|
-
def
|
123
|
+
def __call__(self, state: State) -> bool:
|
124
124
|
if state.num_steps < self.valid_first_n_steps and state.num_valid_steps < self.valid_first_n_steps:
|
125
125
|
return True
|
126
126
|
|
@@ -130,15 +130,18 @@ class ValidStepTimer:
|
|
130
130
|
|
131
131
|
# Step-based validation.
|
132
132
|
valid_every_n_steps = self.valid_every_n_steps
|
133
|
-
if valid_every_n_steps is not None and
|
133
|
+
if valid_every_n_steps is not None and (
|
134
|
+
state.num_steps >= valid_every_n_steps + self.last_valid_step
|
135
|
+
or state.num_valid_steps >= valid_every_n_steps + self.last_valid_step
|
136
|
+
):
|
134
137
|
self._reset(state)
|
135
138
|
return True
|
136
139
|
|
137
140
|
# Time-based validation.
|
138
141
|
valid_every_n_seconds = self.valid_every_n_seconds
|
139
|
-
if (
|
140
|
-
|
141
|
-
|
142
|
+
if valid_every_n_seconds is not None and (
|
143
|
+
state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
|
144
|
+
or state.valid_elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
|
142
145
|
):
|
143
146
|
self._reset(state)
|
144
147
|
return True
|
@@ -146,7 +149,10 @@ class ValidStepTimer:
|
|
146
149
|
# Time-based validation for first validation step.
|
147
150
|
if self.first_valid_step_flag:
|
148
151
|
valid_first_n_seconds = self.valid_first_n_seconds
|
149
|
-
if valid_first_n_seconds is not None and
|
152
|
+
if valid_first_n_seconds is not None and (
|
153
|
+
state.elapsed_time_s.item() >= valid_first_n_seconds
|
154
|
+
or state.valid_elapsed_time_s.item() >= valid_first_n_seconds
|
155
|
+
):
|
150
156
|
self._reset(state)
|
151
157
|
self.first_valid_step_flag = False
|
152
158
|
return True
|
@@ -722,21 +728,27 @@ class TrainMixin(
|
|
722
728
|
model_arr, model_static = eqx.partition(model, self.model_partition_fn)
|
723
729
|
|
724
730
|
while not self.is_training_over(state):
|
725
|
-
|
731
|
+
valid_step = self.valid_step_timer(state)
|
732
|
+
|
733
|
+
if valid_step:
|
726
734
|
with ContextTimer() as timer:
|
735
|
+
state = state.replace(phase="valid")
|
727
736
|
valid_batch = next(valid_pf)
|
728
737
|
output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
|
729
738
|
self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
|
730
739
|
|
740
|
+
state = state.replace(
|
741
|
+
num_valid_steps=state.num_valid_steps + 1,
|
742
|
+
num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
|
743
|
+
)
|
744
|
+
|
731
745
|
state = state.replace(
|
732
|
-
phase="valid",
|
733
|
-
num_valid_steps=state.num_valid_steps + 1,
|
734
|
-
num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
|
735
746
|
valid_elapsed_time_s=state.valid_elapsed_time_s + timer.elapsed_time,
|
736
747
|
)
|
737
748
|
|
738
749
|
with ContextTimer() as timer:
|
739
750
|
state = self.on_step_start(state)
|
751
|
+
state = state.replace(phase="train")
|
740
752
|
train_batch = next(train_pf)
|
741
753
|
model_arr, opt_state, output, metrics = self.train_step(
|
742
754
|
model_arr=model_arr,
|
@@ -748,15 +760,17 @@ class TrainMixin(
|
|
748
760
|
)
|
749
761
|
self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
|
750
762
|
|
763
|
+
state = state.replace(
|
764
|
+
num_steps=state.num_steps + 1,
|
765
|
+
num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
|
766
|
+
)
|
767
|
+
|
768
|
+
state = self.on_step_end(state)
|
769
|
+
|
751
770
|
state = state.replace(
|
752
|
-
phase="train",
|
753
|
-
num_steps=state.num_steps + 1,
|
754
|
-
num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
|
755
771
|
elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
|
756
772
|
)
|
757
773
|
|
758
|
-
state = self.on_step_end(state)
|
759
|
-
|
760
774
|
if self.should_checkpoint(state):
|
761
775
|
model = eqx.combine(model_arr, model_static)
|
762
776
|
self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
|
@@ -1,10 +1,10 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=33wIwGeXDFReg2ZnFqUHfSybj5cKyMqnI8ncj8-9yVg,15510
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
5
5
|
xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
6
|
xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
|
7
|
-
xax/core/state.py,sha256=
|
7
|
+
xax/core/state.py,sha256=KsNMnM_RgsZ2Ntc2pp4Fi6zG4rZb_89-kqmyGxDvyRg,4974
|
8
8
|
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
9
|
xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
10
10
|
xax/nn/equinox.py,sha256=JZuSApD4bL0UK5W1nrQtucWYvNWUha07J6LTLk_RX-Y,4910
|
@@ -29,7 +29,7 @@ xax/task/loggers/callback.py,sha256=zQuV1xCvz47Q3UQqP1D5mBhbVzptvmPR_7hX25vqSk0,
|
|
29
29
|
xax/task/loggers/json.py,sha256=6A5wL7kspsXnpPhI_vu0scgd2Z2-WLhw4gbBFm7eZMM,4377
|
30
30
|
xax/task/loggers/state.py,sha256=0Jy0NYnY4c0qt0LvNlaTaCKOSqk5SCKln5VdyuQGnIc,1407
|
31
31
|
xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,6619
|
32
|
-
xax/task/loggers/tensorboard.py,sha256=
|
32
|
+
xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
|
33
33
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
34
|
xax/task/mixins/artifacts.py,sha256=Ma7fwsp-SA1w6GcuBSskszj5TB83yxYJm4Ns_EnqkI4,3018
|
35
35
|
xax/task/mixins/checkpointing.py,sha256=zqospBFnTbGt_iriiduVfXazINPbzWpwmIs91KAniMY,10147
|
@@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
|
|
41
41
|
xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
|
42
42
|
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
|
-
xax/task/mixins/train.py,sha256=
|
44
|
+
xax/task/mixins/train.py,sha256=_QoxSDMW6nmpH82Un2LDsVIBg9KIx8npRwSjY4TEGYA,31830
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
46
|
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
47
47
|
xax/utils/experiments.py,sha256=bj8BftSHT3fFzfiJ0Co0WvqWo0rUS8kQnQYpVvH8FTM,29942
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.2.
|
62
|
-
xax-0.2.
|
63
|
-
xax-0.2.
|
64
|
-
xax-0.2.
|
65
|
-
xax-0.2.
|
61
|
+
xax-0.2.13.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.2.13.dist-info/METADATA,sha256=-foHRw3ph7yxBmMmjO_oqZqwvdEROYTH4Drc9P58ujI,1880
|
63
|
+
xax-0.2.13.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
|
64
|
+
xax-0.2.13.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.2.13.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|