xax 0.2.11__tar.gz → 0.2.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.2.11/xax.egg-info → xax-0.2.12}/PKG-INFO +1 -1
- {xax-0.2.11 → xax-0.2.12}/xax/__init__.py +1 -1
- {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/train.py +13 -7
- {xax-0.2.11 → xax-0.2.12/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.2.11 → xax-0.2.12}/LICENSE +0 -0
- {xax-0.2.11 → xax-0.2.12}/MANIFEST.in +0 -0
- {xax-0.2.11 → xax-0.2.12}/README.md +0 -0
- {xax-0.2.11 → xax-0.2.12}/pyproject.toml +0 -0
- {xax-0.2.11 → xax-0.2.12}/setup.cfg +0 -0
- {xax-0.2.11 → xax-0.2.12}/setup.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/core/__init__.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/core/conf.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/core/state.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/nn/__init__.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/nn/embeddings.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/nn/equinox.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/nn/export.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/nn/functions.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/nn/geom.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/nn/losses.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/nn/norm.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/nn/parallel.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/nn/ssm.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/py.typed +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/requirements-dev.txt +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/requirements.txt +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/__init__.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/base.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/launchers/__init__.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/launchers/base.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/launchers/cli.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/launchers/single_process.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/logger.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/loggers/__init__.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/loggers/callback.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/loggers/json.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/loggers/state.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/loggers/stdout.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/__init__.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/compile.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/logger.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/process.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/runnable.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/script.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/task/task.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/__init__.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/data/__init__.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/data/collate.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/debugging.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/experiments.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/jax.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/jaxpr.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/logging.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/numpy.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/profile.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/pytree.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/tensorboard.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/text.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/types/__init__.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax.egg-info/requires.txt +0 -0
- {xax-0.2.11 → xax-0.2.12}/xax.egg-info/top_level.txt +0 -0
@@ -120,7 +120,7 @@ class ValidStepTimer:
|
|
120
120
|
self.last_valid_time = state.elapsed_time_s.item()
|
121
121
|
self.last_valid_step = state.num_steps.item()
|
122
122
|
|
123
|
-
def
|
123
|
+
def __call__(self, state: State) -> bool:
|
124
124
|
if state.num_steps < self.valid_first_n_steps and state.num_valid_steps < self.valid_first_n_steps:
|
125
125
|
return True
|
126
126
|
|
@@ -130,15 +130,18 @@ class ValidStepTimer:
|
|
130
130
|
|
131
131
|
# Step-based validation.
|
132
132
|
valid_every_n_steps = self.valid_every_n_steps
|
133
|
-
if valid_every_n_steps is not None and
|
133
|
+
if valid_every_n_steps is not None and (
|
134
|
+
state.num_steps >= valid_every_n_steps + self.last_valid_step
|
135
|
+
or state.num_valid_steps >= valid_every_n_steps + self.last_valid_step
|
136
|
+
):
|
134
137
|
self._reset(state)
|
135
138
|
return True
|
136
139
|
|
137
140
|
# Time-based validation.
|
138
141
|
valid_every_n_seconds = self.valid_every_n_seconds
|
139
|
-
if (
|
140
|
-
|
141
|
-
|
142
|
+
if valid_every_n_seconds is not None and (
|
143
|
+
state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
|
144
|
+
or state.valid_elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
|
142
145
|
):
|
143
146
|
self._reset(state)
|
144
147
|
return True
|
@@ -146,7 +149,10 @@ class ValidStepTimer:
|
|
146
149
|
# Time-based validation for first validation step.
|
147
150
|
if self.first_valid_step_flag:
|
148
151
|
valid_first_n_seconds = self.valid_first_n_seconds
|
149
|
-
if valid_first_n_seconds is not None and
|
152
|
+
if valid_first_n_seconds is not None and (
|
153
|
+
state.elapsed_time_s.item() >= valid_first_n_seconds
|
154
|
+
or state.valid_elapsed_time_s.item() >= valid_first_n_seconds
|
155
|
+
):
|
150
156
|
self._reset(state)
|
151
157
|
self.first_valid_step_flag = False
|
152
158
|
return True
|
@@ -722,7 +728,7 @@ class TrainMixin(
|
|
722
728
|
model_arr, model_static = eqx.partition(model, self.model_partition_fn)
|
723
729
|
|
724
730
|
while not self.is_training_over(state):
|
725
|
-
if self.valid_step_timer
|
731
|
+
if self.valid_step_timer(state):
|
726
732
|
with ContextTimer() as timer:
|
727
733
|
valid_batch = next(valid_pf)
|
728
734
|
output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|