xax 0.2.11__tar.gz → 0.2.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.2.11/xax.egg-info → xax-0.2.12}/PKG-INFO +1 -1
  2. {xax-0.2.11 → xax-0.2.12}/xax/__init__.py +1 -1
  3. {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/train.py +13 -7
  4. {xax-0.2.11 → xax-0.2.12/xax.egg-info}/PKG-INFO +1 -1
  5. {xax-0.2.11 → xax-0.2.12}/LICENSE +0 -0
  6. {xax-0.2.11 → xax-0.2.12}/MANIFEST.in +0 -0
  7. {xax-0.2.11 → xax-0.2.12}/README.md +0 -0
  8. {xax-0.2.11 → xax-0.2.12}/pyproject.toml +0 -0
  9. {xax-0.2.11 → xax-0.2.12}/setup.cfg +0 -0
  10. {xax-0.2.11 → xax-0.2.12}/setup.py +0 -0
  11. {xax-0.2.11 → xax-0.2.12}/xax/core/__init__.py +0 -0
  12. {xax-0.2.11 → xax-0.2.12}/xax/core/conf.py +0 -0
  13. {xax-0.2.11 → xax-0.2.12}/xax/core/state.py +0 -0
  14. {xax-0.2.11 → xax-0.2.12}/xax/nn/__init__.py +0 -0
  15. {xax-0.2.11 → xax-0.2.12}/xax/nn/embeddings.py +0 -0
  16. {xax-0.2.11 → xax-0.2.12}/xax/nn/equinox.py +0 -0
  17. {xax-0.2.11 → xax-0.2.12}/xax/nn/export.py +0 -0
  18. {xax-0.2.11 → xax-0.2.12}/xax/nn/functions.py +0 -0
  19. {xax-0.2.11 → xax-0.2.12}/xax/nn/geom.py +0 -0
  20. {xax-0.2.11 → xax-0.2.12}/xax/nn/losses.py +0 -0
  21. {xax-0.2.11 → xax-0.2.12}/xax/nn/norm.py +0 -0
  22. {xax-0.2.11 → xax-0.2.12}/xax/nn/parallel.py +0 -0
  23. {xax-0.2.11 → xax-0.2.12}/xax/nn/ssm.py +0 -0
  24. {xax-0.2.11 → xax-0.2.12}/xax/py.typed +0 -0
  25. {xax-0.2.11 → xax-0.2.12}/xax/requirements-dev.txt +0 -0
  26. {xax-0.2.11 → xax-0.2.12}/xax/requirements.txt +0 -0
  27. {xax-0.2.11 → xax-0.2.12}/xax/task/__init__.py +0 -0
  28. {xax-0.2.11 → xax-0.2.12}/xax/task/base.py +0 -0
  29. {xax-0.2.11 → xax-0.2.12}/xax/task/launchers/__init__.py +0 -0
  30. {xax-0.2.11 → xax-0.2.12}/xax/task/launchers/base.py +0 -0
  31. {xax-0.2.11 → xax-0.2.12}/xax/task/launchers/cli.py +0 -0
  32. {xax-0.2.11 → xax-0.2.12}/xax/task/launchers/single_process.py +0 -0
  33. {xax-0.2.11 → xax-0.2.12}/xax/task/logger.py +0 -0
  34. {xax-0.2.11 → xax-0.2.12}/xax/task/loggers/__init__.py +0 -0
  35. {xax-0.2.11 → xax-0.2.12}/xax/task/loggers/callback.py +0 -0
  36. {xax-0.2.11 → xax-0.2.12}/xax/task/loggers/json.py +0 -0
  37. {xax-0.2.11 → xax-0.2.12}/xax/task/loggers/state.py +0 -0
  38. {xax-0.2.11 → xax-0.2.12}/xax/task/loggers/stdout.py +0 -0
  39. {xax-0.2.11 → xax-0.2.12}/xax/task/loggers/tensorboard.py +0 -0
  40. {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/__init__.py +0 -0
  41. {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/artifacts.py +0 -0
  42. {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/checkpointing.py +0 -0
  43. {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/compile.py +0 -0
  44. {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/cpu_stats.py +0 -0
  45. {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/data_loader.py +0 -0
  46. {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/gpu_stats.py +0 -0
  47. {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/logger.py +0 -0
  48. {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/process.py +0 -0
  49. {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/runnable.py +0 -0
  50. {xax-0.2.11 → xax-0.2.12}/xax/task/mixins/step_wrapper.py +0 -0
  51. {xax-0.2.11 → xax-0.2.12}/xax/task/script.py +0 -0
  52. {xax-0.2.11 → xax-0.2.12}/xax/task/task.py +0 -0
  53. {xax-0.2.11 → xax-0.2.12}/xax/utils/__init__.py +0 -0
  54. {xax-0.2.11 → xax-0.2.12}/xax/utils/data/__init__.py +0 -0
  55. {xax-0.2.11 → xax-0.2.12}/xax/utils/data/collate.py +0 -0
  56. {xax-0.2.11 → xax-0.2.12}/xax/utils/debugging.py +0 -0
  57. {xax-0.2.11 → xax-0.2.12}/xax/utils/experiments.py +0 -0
  58. {xax-0.2.11 → xax-0.2.12}/xax/utils/jax.py +0 -0
  59. {xax-0.2.11 → xax-0.2.12}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.2.11 → xax-0.2.12}/xax/utils/logging.py +0 -0
  61. {xax-0.2.11 → xax-0.2.12}/xax/utils/numpy.py +0 -0
  62. {xax-0.2.11 → xax-0.2.12}/xax/utils/profile.py +0 -0
  63. {xax-0.2.11 → xax-0.2.12}/xax/utils/pytree.py +0 -0
  64. {xax-0.2.11 → xax-0.2.12}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.2.11 → xax-0.2.12}/xax/utils/text.py +0 -0
  66. {xax-0.2.11 → xax-0.2.12}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.2.11 → xax-0.2.12}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.2.11 → xax-0.2.12}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.11 → xax-0.2.12}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.11 → xax-0.2.12}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.11 → xax-0.2.12}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.11 → xax-0.2.12}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.11
3
+ Version: 0.2.12
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.11"
15
+ __version__ = "0.2.12"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -120,7 +120,7 @@ class ValidStepTimer:
120
120
  self.last_valid_time = state.elapsed_time_s.item()
121
121
  self.last_valid_step = state.num_steps.item()
122
122
 
123
- def is_valid_step(self, state: State) -> bool:
123
+ def __call__(self, state: State) -> bool:
124
124
  if state.num_steps < self.valid_first_n_steps and state.num_valid_steps < self.valid_first_n_steps:
125
125
  return True
126
126
 
@@ -130,15 +130,18 @@ class ValidStepTimer:
130
130
 
131
131
  # Step-based validation.
132
132
  valid_every_n_steps = self.valid_every_n_steps
133
- if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
133
+ if valid_every_n_steps is not None and (
134
+ state.num_steps >= valid_every_n_steps + self.last_valid_step
135
+ or state.num_valid_steps >= valid_every_n_steps + self.last_valid_step
136
+ ):
134
137
  self._reset(state)
135
138
  return True
136
139
 
137
140
  # Time-based validation.
138
141
  valid_every_n_seconds = self.valid_every_n_seconds
139
- if (
140
- valid_every_n_seconds is not None
141
- and state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
142
+ if valid_every_n_seconds is not None and (
143
+ state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
144
+ or state.valid_elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
142
145
  ):
143
146
  self._reset(state)
144
147
  return True
@@ -146,7 +149,10 @@ class ValidStepTimer:
146
149
  # Time-based validation for first validation step.
147
150
  if self.first_valid_step_flag:
148
151
  valid_first_n_seconds = self.valid_first_n_seconds
149
- if valid_first_n_seconds is not None and state.elapsed_time_s.item() >= valid_first_n_seconds:
152
+ if valid_first_n_seconds is not None and (
153
+ state.elapsed_time_s.item() >= valid_first_n_seconds
154
+ or state.valid_elapsed_time_s.item() >= valid_first_n_seconds
155
+ ):
150
156
  self._reset(state)
151
157
  self.first_valid_step_flag = False
152
158
  return True
@@ -722,7 +728,7 @@ class TrainMixin(
722
728
  model_arr, model_static = eqx.partition(model, self.model_partition_fn)
723
729
 
724
730
  while not self.is_training_over(state):
725
- if self.valid_step_timer.is_valid_step(state):
731
+ if self.valid_step_timer(state):
726
732
  with ContextTimer() as timer:
727
733
  valid_batch = next(valid_pf)
728
734
  output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.11
3
+ Version: 0.2.12
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes