xax 0.2.5__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72):
  1. {xax-0.2.5/xax.egg-info → xax-0.2.6}/PKG-INFO +1 -1
  2. {xax-0.2.5 → xax-0.2.6}/xax/__init__.py +1 -1
  3. {xax-0.2.5 → xax-0.2.6}/xax/task/logger.py +2 -1
  4. {xax-0.2.5 → xax-0.2.6}/xax/task/mixins/train.py +8 -5
  5. {xax-0.2.5 → xax-0.2.6/xax.egg-info}/PKG-INFO +1 -1
  6. {xax-0.2.5 → xax-0.2.6}/LICENSE +0 -0
  7. {xax-0.2.5 → xax-0.2.6}/MANIFEST.in +0 -0
  8. {xax-0.2.5 → xax-0.2.6}/README.md +0 -0
  9. {xax-0.2.5 → xax-0.2.6}/pyproject.toml +0 -0
  10. {xax-0.2.5 → xax-0.2.6}/setup.cfg +0 -0
  11. {xax-0.2.5 → xax-0.2.6}/setup.py +0 -0
  12. {xax-0.2.5 → xax-0.2.6}/xax/core/__init__.py +0 -0
  13. {xax-0.2.5 → xax-0.2.6}/xax/core/conf.py +0 -0
  14. {xax-0.2.5 → xax-0.2.6}/xax/core/state.py +0 -0
  15. {xax-0.2.5 → xax-0.2.6}/xax/nn/__init__.py +0 -0
  16. {xax-0.2.5 → xax-0.2.6}/xax/nn/embeddings.py +0 -0
  17. {xax-0.2.5 → xax-0.2.6}/xax/nn/equinox.py +0 -0
  18. {xax-0.2.5 → xax-0.2.6}/xax/nn/export.py +0 -0
  19. {xax-0.2.5 → xax-0.2.6}/xax/nn/functions.py +0 -0
  20. {xax-0.2.5 → xax-0.2.6}/xax/nn/geom.py +0 -0
  21. {xax-0.2.5 → xax-0.2.6}/xax/nn/losses.py +0 -0
  22. {xax-0.2.5 → xax-0.2.6}/xax/nn/norm.py +0 -0
  23. {xax-0.2.5 → xax-0.2.6}/xax/nn/parallel.py +0 -0
  24. {xax-0.2.5 → xax-0.2.6}/xax/nn/ssm.py +0 -0
  25. {xax-0.2.5 → xax-0.2.6}/xax/py.typed +0 -0
  26. {xax-0.2.5 → xax-0.2.6}/xax/requirements-dev.txt +0 -0
  27. {xax-0.2.5 → xax-0.2.6}/xax/requirements.txt +0 -0
  28. {xax-0.2.5 → xax-0.2.6}/xax/task/__init__.py +0 -0
  29. {xax-0.2.5 → xax-0.2.6}/xax/task/base.py +0 -0
  30. {xax-0.2.5 → xax-0.2.6}/xax/task/launchers/__init__.py +0 -0
  31. {xax-0.2.5 → xax-0.2.6}/xax/task/launchers/base.py +0 -0
  32. {xax-0.2.5 → xax-0.2.6}/xax/task/launchers/cli.py +0 -0
  33. {xax-0.2.5 → xax-0.2.6}/xax/task/launchers/single_process.py +0 -0
  34. {xax-0.2.5 → xax-0.2.6}/xax/task/loggers/__init__.py +0 -0
  35. {xax-0.2.5 → xax-0.2.6}/xax/task/loggers/callback.py +0 -0
  36. {xax-0.2.5 → xax-0.2.6}/xax/task/loggers/json.py +0 -0
  37. {xax-0.2.5 → xax-0.2.6}/xax/task/loggers/state.py +0 -0
  38. {xax-0.2.5 → xax-0.2.6}/xax/task/loggers/stdout.py +0 -0
  39. {xax-0.2.5 → xax-0.2.6}/xax/task/loggers/tensorboard.py +0 -0
  40. {xax-0.2.5 → xax-0.2.6}/xax/task/mixins/__init__.py +0 -0
  41. {xax-0.2.5 → xax-0.2.6}/xax/task/mixins/artifacts.py +0 -0
  42. {xax-0.2.5 → xax-0.2.6}/xax/task/mixins/checkpointing.py +0 -0
  43. {xax-0.2.5 → xax-0.2.6}/xax/task/mixins/compile.py +0 -0
  44. {xax-0.2.5 → xax-0.2.6}/xax/task/mixins/cpu_stats.py +0 -0
  45. {xax-0.2.5 → xax-0.2.6}/xax/task/mixins/data_loader.py +0 -0
  46. {xax-0.2.5 → xax-0.2.6}/xax/task/mixins/gpu_stats.py +0 -0
  47. {xax-0.2.5 → xax-0.2.6}/xax/task/mixins/logger.py +0 -0
  48. {xax-0.2.5 → xax-0.2.6}/xax/task/mixins/process.py +0 -0
  49. {xax-0.2.5 → xax-0.2.6}/xax/task/mixins/runnable.py +0 -0
  50. {xax-0.2.5 → xax-0.2.6}/xax/task/mixins/step_wrapper.py +0 -0
  51. {xax-0.2.5 → xax-0.2.6}/xax/task/script.py +0 -0
  52. {xax-0.2.5 → xax-0.2.6}/xax/task/task.py +0 -0
  53. {xax-0.2.5 → xax-0.2.6}/xax/utils/__init__.py +0 -0
  54. {xax-0.2.5 → xax-0.2.6}/xax/utils/data/__init__.py +0 -0
  55. {xax-0.2.5 → xax-0.2.6}/xax/utils/data/collate.py +0 -0
  56. {xax-0.2.5 → xax-0.2.6}/xax/utils/debugging.py +0 -0
  57. {xax-0.2.5 → xax-0.2.6}/xax/utils/experiments.py +0 -0
  58. {xax-0.2.5 → xax-0.2.6}/xax/utils/jax.py +0 -0
  59. {xax-0.2.5 → xax-0.2.6}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.2.5 → xax-0.2.6}/xax/utils/logging.py +0 -0
  61. {xax-0.2.5 → xax-0.2.6}/xax/utils/numpy.py +0 -0
  62. {xax-0.2.5 → xax-0.2.6}/xax/utils/profile.py +0 -0
  63. {xax-0.2.5 → xax-0.2.6}/xax/utils/pytree.py +0 -0
  64. {xax-0.2.5 → xax-0.2.6}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.2.5 → xax-0.2.6}/xax/utils/text.py +0 -0
  66. {xax-0.2.5 → xax-0.2.6}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.2.5 → xax-0.2.6}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.2.5 → xax-0.2.6}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.5 → xax-0.2.6}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.5 → xax-0.2.6}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.5 → xax-0.2.6}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.5 → xax-0.2.6}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.5"
15
+ __version__ = "0.2.6"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -521,7 +521,8 @@ class LoggerImpl(ABC):
521
521
  Returns:
522
522
  If the logger should log the current step.
523
523
  """
524
- return self.tickers[state.phase].tick(state.elapsed_time_s.item())
524
+ elapsed_time = state.elapsed_time_s.item() if state.phase == "train" else state.valid_elapsed_time_s.item()
525
+ return self.tickers[state.phase].tick(elapsed_time)
525
526
 
526
527
 
527
528
  class ToastHandler(logging.Handler):
@@ -115,19 +115,22 @@ class ValidStepTimer:
115
115
  self.last_valid_time: float | None = None
116
116
  self.last_valid_step: int | None = None
117
117
 
118
+ def _reset(self, state: State) -> None:
119
+ self.last_valid_time = state.elapsed_time_s.item()
120
+ self.last_valid_step = state.num_steps.item()
121
+
118
122
  def is_valid_step(self, state: State) -> bool:
119
123
  if state.num_steps < self.valid_first_n_steps:
120
124
  return True
121
125
 
122
126
  if self.last_valid_time is None or self.last_valid_step is None:
123
- self.last_valid_time = state.elapsed_time_s.item()
124
- self.last_valid_step = state.num_steps.item()
127
+ self._reset(state)
125
128
  return False
126
129
 
127
130
  # Step-based validation.
128
131
  valid_every_n_steps = self.valid_every_n_steps
129
132
  if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
130
- self.last_valid_step = state.num_steps.item()
133
+ self._reset(state)
131
134
  return True
132
135
 
133
136
  # Time-based validation.
@@ -136,14 +139,14 @@ class ValidStepTimer:
136
139
  valid_every_n_seconds is not None
137
140
  and state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
138
141
  ):
139
- self.last_valid_time = state.elapsed_time_s.item()
142
+ self._reset(state)
140
143
  return True
141
144
 
142
145
  # Time-based validation for first validation step.
143
146
  if self.first_valid_step_flag:
144
147
  valid_first_n_seconds = self.valid_first_n_seconds
145
148
  if valid_first_n_seconds is not None and state.elapsed_time_s.item() >= valid_first_n_seconds:
146
- self.last_valid_time = state.elapsed_time_s.item()
149
+ self._reset(state)
147
150
  self.first_valid_step_flag = False
148
151
  return True
149
152
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes