xax 0.2.3__tar.gz → 0.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.2.3/xax.egg-info → xax-0.2.4}/PKG-INFO +1 -1
- {xax-0.2.3 → xax-0.2.4}/xax/__init__.py +1 -1
- {xax-0.2.3 → xax-0.2.4}/xax/core/state.py +18 -17
- {xax-0.2.3 → xax-0.2.4}/xax/task/loggers/tensorboard.py +1 -1
- {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/train.py +8 -7
- {xax-0.2.3 → xax-0.2.4/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.2.3 → xax-0.2.4}/LICENSE +0 -0
- {xax-0.2.3 → xax-0.2.4}/MANIFEST.in +0 -0
- {xax-0.2.3 → xax-0.2.4}/README.md +0 -0
- {xax-0.2.3 → xax-0.2.4}/pyproject.toml +0 -0
- {xax-0.2.3 → xax-0.2.4}/setup.cfg +0 -0
- {xax-0.2.3 → xax-0.2.4}/setup.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/core/__init__.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/core/conf.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/nn/__init__.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/nn/embeddings.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/nn/equinox.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/nn/export.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/nn/functions.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/nn/geom.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/nn/losses.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/nn/norm.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/nn/parallel.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/nn/ssm.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/py.typed +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/requirements-dev.txt +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/requirements.txt +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/__init__.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/base.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/launchers/__init__.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/launchers/base.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/launchers/cli.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/launchers/single_process.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/logger.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/loggers/__init__.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/loggers/callback.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/loggers/json.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/loggers/state.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/loggers/stdout.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/__init__.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/compile.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/logger.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/process.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/runnable.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/script.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/task/task.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/__init__.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/data/__init__.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/data/collate.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/debugging.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/experiments.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/jax.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/jaxpr.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/logging.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/numpy.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/profile.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/pytree.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/tensorboard.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/text.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/types/__init__.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax.egg-info/requires.txt +0 -0
- {xax-0.2.3 → xax-0.2.4}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,5 @@
|
|
1
1
|
"""Defines a dataclass for keeping track of the current training state."""
|
2
2
|
|
3
|
-
import time
|
4
3
|
from dataclasses import dataclass
|
5
4
|
from typing import Literal, NotRequired, TypedDict, Unpack, cast
|
6
5
|
|
@@ -19,6 +18,8 @@ def _phase_to_int(phase: Phase) -> int:
|
|
19
18
|
|
20
19
|
|
21
20
|
def _int_to_phase(i: int) -> Phase:
|
21
|
+
if i < 0 or i > 1:
|
22
|
+
raise ValueError(f"Invalid phase: {i}")
|
22
23
|
return cast(Phase, ["train", "valid"][i])
|
23
24
|
|
24
25
|
|
@@ -27,8 +28,8 @@ class StateDict(TypedDict, total=False):
|
|
27
28
|
num_samples: NotRequired[int | Array]
|
28
29
|
num_valid_steps: NotRequired[int | Array]
|
29
30
|
num_valid_samples: NotRequired[int | Array]
|
30
|
-
start_time_s: NotRequired[float | Array]
|
31
31
|
elapsed_time_s: NotRequired[float | Array]
|
32
|
+
valid_elapsed_time_s: NotRequired[float | Array]
|
32
33
|
phase: NotRequired[Phase]
|
33
34
|
_phase: NotRequired[int | Array]
|
34
35
|
|
@@ -43,24 +44,24 @@ class State:
|
|
43
44
|
def num_steps(self) -> Array:
|
44
45
|
return self._int32_arr[0]
|
45
46
|
|
46
|
-
@property
|
47
|
-
def num_samples(self) -> Array:
|
48
|
-
return self._float32_arr[0]
|
49
|
-
|
50
47
|
@property
|
51
48
|
def num_valid_steps(self) -> Array:
|
52
49
|
return self._int32_arr[1]
|
53
50
|
|
51
|
+
@property
|
52
|
+
def num_samples(self) -> Array:
|
53
|
+
return self._float32_arr[0]
|
54
|
+
|
54
55
|
@property
|
55
56
|
def num_valid_samples(self) -> Array:
|
56
57
|
return self._float32_arr[1]
|
57
58
|
|
58
59
|
@property
|
59
|
-
def start_time_s(self) -> Array:
|
60
|
+
def elapsed_time_s(self) -> Array:
|
60
61
|
return self._float32_arr[2]
|
61
62
|
|
62
63
|
@property
|
63
|
-
def elapsed_time_s(self) -> Array:
|
64
|
+
def valid_elapsed_time_s(self) -> Array:
|
64
65
|
return self._float32_arr[3]
|
65
66
|
|
66
67
|
@property
|
@@ -71,7 +72,7 @@ class State:
|
|
71
72
|
def init_state(cls) -> "State":
|
72
73
|
return cls(
|
73
74
|
_int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
|
74
|
-
_float32_arr=jnp.array([0.0, 0.0,
|
75
|
+
_float32_arr=jnp.array([0.0, 0.0, 0.0, 0.0], dtype=jnp.float32),
|
75
76
|
)
|
76
77
|
|
77
78
|
@property
|
@@ -97,10 +98,10 @@ class State:
|
|
97
98
|
if "num_valid_samples" in kwargs:
|
98
99
|
float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
|
99
100
|
|
100
|
-
if "start_time_s" in kwargs:
|
101
|
-
float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
|
102
101
|
if "elapsed_time_s" in kwargs:
|
103
|
-
float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
|
102
|
+
float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])
|
103
|
+
if "valid_elapsed_time_s" in kwargs:
|
104
|
+
float32_arr = float32_arr.at[3].set(kwargs["valid_elapsed_time_s"])
|
104
105
|
|
105
106
|
return State(
|
106
107
|
_int32_arr=int32_arr,
|
@@ -110,11 +111,11 @@ class State:
|
|
110
111
|
def to_dict(self) -> dict[str, int | float | str]:
|
111
112
|
return {
|
112
113
|
"num_steps": int(self.num_steps),
|
113
|
-
"num_samples": int(self.num_samples),
|
114
114
|
"num_valid_steps": int(self.num_valid_steps),
|
115
|
+
"num_samples": int(self.num_samples),
|
115
116
|
"num_valid_samples": int(self.num_valid_samples),
|
116
|
-
"start_time_s": float(self.start_time_s),
|
117
117
|
"elapsed_time_s": float(self.elapsed_time_s),
|
118
|
+
"valid_elapsed_time_s": float(self.valid_elapsed_time_s),
|
118
119
|
"phase": str(self.phase),
|
119
120
|
}
|
120
121
|
|
@@ -126,9 +127,7 @@ class State:
|
|
126
127
|
int32_arr = jnp.array(
|
127
128
|
[
|
128
129
|
d.get("num_steps", 0),
|
129
|
-
d.get("num_samples", 0),
|
130
130
|
d.get("num_valid_steps", 0),
|
131
|
-
d.get("num_valid_samples", 0),
|
132
131
|
d.get("_phase", 0),
|
133
132
|
],
|
134
133
|
dtype=jnp.int32,
|
@@ -136,8 +135,10 @@ class State:
|
|
136
135
|
|
137
136
|
float32_arr = jnp.array(
|
138
137
|
[
|
139
|
-
d.get("
|
138
|
+
d.get("num_samples", 0),
|
139
|
+
d.get("num_valid_samples", 0),
|
140
140
|
d.get("elapsed_time_s", 0.0),
|
141
|
+
d.get("valid_elapsed_time_s", 0.0),
|
141
142
|
],
|
142
143
|
dtype=jnp.float32,
|
143
144
|
)
|
@@ -157,7 +157,7 @@ class TensorboardLogger(LoggerImpl):
|
|
157
157
|
writer = self.get_writer(line.state.phase)
|
158
158
|
|
159
159
|
global_step = line.state.num_steps.item()
|
160
|
-
walltime =
|
160
|
+
walltime = line.state.elapsed_time_s.item()
|
161
161
|
|
162
162
|
for namespace, scalars in line.scalars.items():
|
163
163
|
for scalar_key, scalar_value in scalars.items():
|
@@ -720,20 +720,21 @@ class TrainMixin(
|
|
720
720
|
|
721
721
|
while not self.is_training_over(state):
|
722
722
|
if self.valid_step_timer.is_valid_step(state):
|
723
|
-
|
723
|
+
with ContextTimer() as timer:
|
724
|
+
valid_batch = next(valid_pf)
|
725
|
+
output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
|
726
|
+
self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
|
727
|
+
|
724
728
|
state = state.replace(
|
725
729
|
phase="valid",
|
726
730
|
num_valid_steps=state.num_valid_steps + 1,
|
727
731
|
num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
|
732
|
+
valid_elapsed_time_s=state.valid_elapsed_time_s + timer.elapsed_time,
|
728
733
|
)
|
729
734
|
|
730
|
-
output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
|
731
|
-
self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
|
732
|
-
|
733
|
-
state = self.on_step_start(state)
|
734
|
-
train_batch = next(train_pf)
|
735
|
-
|
736
735
|
with ContextTimer() as timer:
|
736
|
+
state = self.on_step_start(state)
|
737
|
+
train_batch = next(train_pf)
|
737
738
|
model_arr, opt_state, output, metrics = self.train_step(
|
738
739
|
model_arr=model_arr,
|
739
740
|
model_static=model_static,
|
{xax-0.2.3 → xax-0.2.4}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{xax-0.2.3 → xax-0.2.4}/setup.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|