xax 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/core/state.py +21 -11
- xax/task/loggers/tensorboard.py +1 -1
- xax/task/mixins/train.py +17 -9
- {xax-0.2.12.dist-info → xax-0.2.13.dist-info}/METADATA +1 -1
- {xax-0.2.12.dist-info → xax-0.2.13.dist-info}/RECORD +9 -9
- {xax-0.2.12.dist-info → xax-0.2.13.dist-info}/WHEEL +0 -0
- {xax-0.2.12.dist-info → xax-0.2.13.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.12.dist-info → xax-0.2.13.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
xax/core/state.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
"""Defines a dataclass for keeping track of the current training state."""
|
2
2
|
|
3
|
+
import time
|
3
4
|
from dataclasses import dataclass
|
4
5
|
from typing import Literal, NotRequired, TypedDict, Unpack, cast
|
5
6
|
|
@@ -28,6 +29,7 @@ class StateDict(TypedDict, total=False):
|
|
28
29
|
num_samples: NotRequired[int | Array]
|
29
30
|
num_valid_steps: NotRequired[int | Array]
|
30
31
|
num_valid_samples: NotRequired[int | Array]
|
32
|
+
start_time_s: NotRequired[float | Array]
|
31
33
|
elapsed_time_s: NotRequired[float | Array]
|
32
34
|
valid_elapsed_time_s: NotRequired[float | Array]
|
33
35
|
phase: NotRequired[Phase]
|
@@ -57,13 +59,17 @@ class State:
|
|
57
59
|
return self._float32_arr[1]
|
58
60
|
|
59
61
|
@property
|
60
|
-
def elapsed_time_s(self) -> Array:
|
62
|
+
def start_time_s(self) -> Array:
|
61
63
|
return self._float32_arr[2]
|
62
64
|
|
63
65
|
@property
|
64
|
-
def valid_elapsed_time_s(self) -> Array:
|
66
|
+
def elapsed_time_s(self) -> Array:
|
65
67
|
return self._float32_arr[3]
|
66
68
|
|
69
|
+
@property
|
70
|
+
def valid_elapsed_time_s(self) -> Array:
|
71
|
+
return self._float32_arr[4]
|
72
|
+
|
67
73
|
@property
|
68
74
|
def phase(self) -> Phase:
|
69
75
|
return _int_to_phase(self._int32_arr[2].item())
|
@@ -72,7 +78,7 @@ class State:
|
|
72
78
|
def init_state(cls) -> "State":
|
73
79
|
return cls(
|
74
80
|
_int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
|
75
|
-
_float32_arr=jnp.array([0.0, 0.0, 0.0, 0.0], dtype=jnp.float32),
|
81
|
+
_float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0, 0.0], dtype=jnp.float32),
|
76
82
|
)
|
77
83
|
|
78
84
|
@property
|
@@ -98,10 +104,12 @@ class State:
|
|
98
104
|
if "num_valid_samples" in kwargs:
|
99
105
|
float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
|
100
106
|
|
107
|
+
if "start_time_s" in kwargs:
|
108
|
+
float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
|
101
109
|
if "elapsed_time_s" in kwargs:
|
102
|
-
float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])
|
110
|
+
float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
|
103
111
|
if "valid_elapsed_time_s" in kwargs:
|
104
|
-
float32_arr = float32_arr.at[3].set(kwargs["valid_elapsed_time_s"])
|
112
|
+
float32_arr = float32_arr.at[4].set(kwargs["valid_elapsed_time_s"])
|
105
113
|
|
106
114
|
return State(
|
107
115
|
_int32_arr=int32_arr,
|
@@ -110,12 +118,13 @@ class State:
|
|
110
118
|
|
111
119
|
def to_dict(self) -> dict[str, int | float | str]:
|
112
120
|
return {
|
113
|
-
"num_steps": int(self.num_steps),
|
114
|
-
"num_valid_steps": int(self.num_valid_steps),
|
115
|
-
"num_samples": int(self.num_samples),
|
116
|
-
"num_valid_samples": int(self.num_valid_samples),
|
117
|
-
"elapsed_time_s": float(self.elapsed_time_s),
|
118
|
-
"valid_elapsed_time_s": float(self.valid_elapsed_time_s),
|
121
|
+
"num_steps": int(self.num_steps.item()),
|
122
|
+
"num_valid_steps": int(self.num_valid_steps.item()),
|
123
|
+
"num_samples": int(self.num_samples.item()),
|
124
|
+
"num_valid_samples": int(self.num_valid_samples.item()),
|
125
|
+
"start_time_s": float(self.start_time_s.item()),
|
126
|
+
"elapsed_time_s": float(self.elapsed_time_s.item()),
|
127
|
+
"valid_elapsed_time_s": float(self.valid_elapsed_time_s.item()),
|
119
128
|
"phase": str(self.phase),
|
120
129
|
}
|
121
130
|
|
@@ -137,6 +146,7 @@ class State:
|
|
137
146
|
[
|
138
147
|
d.get("num_samples", 0),
|
139
148
|
d.get("num_valid_samples", 0),
|
149
|
+
d.get("start_time_s", time.time()),
|
140
150
|
d.get("elapsed_time_s", 0.0),
|
141
151
|
d.get("valid_elapsed_time_s", 0.0),
|
142
152
|
],
|
xax/task/loggers/tensorboard.py
CHANGED
@@ -160,7 +160,7 @@ class TensorboardLogger(LoggerImpl):
|
|
160
160
|
writer = self.get_writer(line.state.phase)
|
161
161
|
|
162
162
|
global_step = line.state.num_steps.item()
|
163
|
-
walltime = line.state.elapsed_time_s.item()
|
163
|
+
walltime = line.state.start_time_s.item() + line.state.elapsed_time_s.item()
|
164
164
|
|
165
165
|
for namespace, scalars in line.scalars.items():
|
166
166
|
for scalar_key, scalar_value in scalars.items():
|
xax/task/mixins/train.py
CHANGED
@@ -728,21 +728,27 @@ class TrainMixin(
|
|
728
728
|
model_arr, model_static = eqx.partition(model, self.model_partition_fn)
|
729
729
|
|
730
730
|
while not self.is_training_over(state):
|
731
|
-
|
731
|
+
valid_step = self.valid_step_timer(state)
|
732
|
+
|
733
|
+
if valid_step:
|
732
734
|
with ContextTimer() as timer:
|
735
|
+
state = state.replace(phase="valid")
|
733
736
|
valid_batch = next(valid_pf)
|
734
737
|
output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
|
735
738
|
self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
|
736
739
|
|
740
|
+
state = state.replace(
|
741
|
+
num_valid_steps=state.num_valid_steps + 1,
|
742
|
+
num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
|
743
|
+
)
|
744
|
+
|
737
745
|
state = state.replace(
|
738
|
-
phase="valid",
|
739
|
-
num_valid_steps=state.num_valid_steps + 1,
|
740
|
-
num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
|
741
746
|
valid_elapsed_time_s=state.valid_elapsed_time_s + timer.elapsed_time,
|
742
747
|
)
|
743
748
|
|
744
749
|
with ContextTimer() as timer:
|
745
750
|
state = self.on_step_start(state)
|
751
|
+
state = state.replace(phase="train")
|
746
752
|
train_batch = next(train_pf)
|
747
753
|
model_arr, opt_state, output, metrics = self.train_step(
|
748
754
|
model_arr=model_arr,
|
@@ -754,15 +760,17 @@ class TrainMixin(
|
|
754
760
|
)
|
755
761
|
self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
|
756
762
|
|
763
|
+
state = state.replace(
|
764
|
+
num_steps=state.num_steps + 1,
|
765
|
+
num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
|
766
|
+
)
|
767
|
+
|
768
|
+
state = self.on_step_end(state)
|
769
|
+
|
757
770
|
state = state.replace(
|
758
|
-
phase="train",
|
759
|
-
num_steps=state.num_steps + 1,
|
760
|
-
num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
|
761
771
|
elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
|
762
772
|
)
|
763
773
|
|
764
|
-
state = self.on_step_end(state)
|
765
|
-
|
766
774
|
if self.should_checkpoint(state):
|
767
775
|
model = eqx.combine(model_arr, model_static)
|
768
776
|
self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
|
@@ -1,10 +1,10 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=33wIwGeXDFReg2ZnFqUHfSybj5cKyMqnI8ncj8-9yVg,15510
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
5
5
|
xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
6
|
xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
|
7
|
-
xax/core/state.py,sha256=
|
7
|
+
xax/core/state.py,sha256=KsNMnM_RgsZ2Ntc2pp4Fi6zG4rZb_89-kqmyGxDvyRg,4974
|
8
8
|
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
9
|
xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
10
10
|
xax/nn/equinox.py,sha256=JZuSApD4bL0UK5W1nrQtucWYvNWUha07J6LTLk_RX-Y,4910
|
@@ -29,7 +29,7 @@ xax/task/loggers/callback.py,sha256=zQuV1xCvz47Q3UQqP1D5mBhbVzptvmPR_7hX25vqSk0,
|
|
29
29
|
xax/task/loggers/json.py,sha256=6A5wL7kspsXnpPhI_vu0scgd2Z2-WLhw4gbBFm7eZMM,4377
|
30
30
|
xax/task/loggers/state.py,sha256=0Jy0NYnY4c0qt0LvNlaTaCKOSqk5SCKln5VdyuQGnIc,1407
|
31
31
|
xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,6619
|
32
|
-
xax/task/loggers/tensorboard.py,sha256=
|
32
|
+
xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
|
33
33
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
34
|
xax/task/mixins/artifacts.py,sha256=Ma7fwsp-SA1w6GcuBSskszj5TB83yxYJm4Ns_EnqkI4,3018
|
35
35
|
xax/task/mixins/checkpointing.py,sha256=zqospBFnTbGt_iriiduVfXazINPbzWpwmIs91KAniMY,10147
|
@@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
|
|
41
41
|
xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
|
42
42
|
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
|
-
xax/task/mixins/train.py,sha256=
|
44
|
+
xax/task/mixins/train.py,sha256=_QoxSDMW6nmpH82Un2LDsVIBg9KIx8npRwSjY4TEGYA,31830
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
46
|
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
47
47
|
xax/utils/experiments.py,sha256=bj8BftSHT3fFzfiJ0Co0WvqWo0rUS8kQnQYpVvH8FTM,29942
|
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
59
|
xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.2.
|
62
|
-
xax-0.2.
|
63
|
-
xax-0.2.
|
64
|
-
xax-0.2.
|
65
|
-
xax-0.2.
|
61
|
+
xax-0.2.13.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.2.13.dist-info/METADATA,sha256=-foHRw3ph7yxBmMmjO_oqZqwvdEROYTH4Drc9P58ujI,1880
|
63
|
+
xax-0.2.13.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
|
64
|
+
xax-0.2.13.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.2.13.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|