xax 0.1.1__tar.gz → 0.1.2__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (67):
  1. {xax-0.1.1/xax.egg-info → xax-0.1.2}/PKG-INFO +1 -1
  2. {xax-0.1.1 → xax-0.1.2}/xax/__init__.py +4 -2
  3. {xax-0.1.1 → xax-0.1.2}/xax/utils/experiments.py +2 -2
  4. {xax-0.1.1 → xax-0.1.2/xax.egg-info}/PKG-INFO +1 -1
  5. {xax-0.1.1 → xax-0.1.2}/LICENSE +0 -0
  6. {xax-0.1.1 → xax-0.1.2}/MANIFEST.in +0 -0
  7. {xax-0.1.1 → xax-0.1.2}/README.md +0 -0
  8. {xax-0.1.1 → xax-0.1.2}/pyproject.toml +0 -0
  9. {xax-0.1.1 → xax-0.1.2}/setup.cfg +0 -0
  10. {xax-0.1.1 → xax-0.1.2}/setup.py +0 -0
  11. {xax-0.1.1 → xax-0.1.2}/xax/core/__init__.py +0 -0
  12. {xax-0.1.1 → xax-0.1.2}/xax/core/conf.py +0 -0
  13. {xax-0.1.1 → xax-0.1.2}/xax/core/state.py +0 -0
  14. {xax-0.1.1 → xax-0.1.2}/xax/nn/__init__.py +0 -0
  15. {xax-0.1.1 → xax-0.1.2}/xax/nn/embeddings.py +0 -0
  16. {xax-0.1.1 → xax-0.1.2}/xax/nn/equinox.py +0 -0
  17. {xax-0.1.1 → xax-0.1.2}/xax/nn/export.py +0 -0
  18. {xax-0.1.1 → xax-0.1.2}/xax/nn/functions.py +0 -0
  19. {xax-0.1.1 → xax-0.1.2}/xax/nn/geom.py +0 -0
  20. {xax-0.1.1 → xax-0.1.2}/xax/nn/norm.py +0 -0
  21. {xax-0.1.1 → xax-0.1.2}/xax/nn/parallel.py +0 -0
  22. {xax-0.1.1 → xax-0.1.2}/xax/py.typed +0 -0
  23. {xax-0.1.1 → xax-0.1.2}/xax/requirements-dev.txt +0 -0
  24. {xax-0.1.1 → xax-0.1.2}/xax/requirements.txt +0 -0
  25. {xax-0.1.1 → xax-0.1.2}/xax/task/__init__.py +0 -0
  26. {xax-0.1.1 → xax-0.1.2}/xax/task/base.py +0 -0
  27. {xax-0.1.1 → xax-0.1.2}/xax/task/launchers/__init__.py +0 -0
  28. {xax-0.1.1 → xax-0.1.2}/xax/task/launchers/base.py +0 -0
  29. {xax-0.1.1 → xax-0.1.2}/xax/task/launchers/cli.py +0 -0
  30. {xax-0.1.1 → xax-0.1.2}/xax/task/launchers/single_process.py +0 -0
  31. {xax-0.1.1 → xax-0.1.2}/xax/task/logger.py +0 -0
  32. {xax-0.1.1 → xax-0.1.2}/xax/task/loggers/__init__.py +0 -0
  33. {xax-0.1.1 → xax-0.1.2}/xax/task/loggers/callback.py +0 -0
  34. {xax-0.1.1 → xax-0.1.2}/xax/task/loggers/json.py +0 -0
  35. {xax-0.1.1 → xax-0.1.2}/xax/task/loggers/state.py +0 -0
  36. {xax-0.1.1 → xax-0.1.2}/xax/task/loggers/stdout.py +0 -0
  37. {xax-0.1.1 → xax-0.1.2}/xax/task/loggers/tensorboard.py +0 -0
  38. {xax-0.1.1 → xax-0.1.2}/xax/task/mixins/__init__.py +0 -0
  39. {xax-0.1.1 → xax-0.1.2}/xax/task/mixins/artifacts.py +0 -0
  40. {xax-0.1.1 → xax-0.1.2}/xax/task/mixins/checkpointing.py +0 -0
  41. {xax-0.1.1 → xax-0.1.2}/xax/task/mixins/compile.py +0 -0
  42. {xax-0.1.1 → xax-0.1.2}/xax/task/mixins/cpu_stats.py +0 -0
  43. {xax-0.1.1 → xax-0.1.2}/xax/task/mixins/data_loader.py +0 -0
  44. {xax-0.1.1 → xax-0.1.2}/xax/task/mixins/gpu_stats.py +0 -0
  45. {xax-0.1.1 → xax-0.1.2}/xax/task/mixins/logger.py +0 -0
  46. {xax-0.1.1 → xax-0.1.2}/xax/task/mixins/process.py +0 -0
  47. {xax-0.1.1 → xax-0.1.2}/xax/task/mixins/runnable.py +0 -0
  48. {xax-0.1.1 → xax-0.1.2}/xax/task/mixins/step_wrapper.py +0 -0
  49. {xax-0.1.1 → xax-0.1.2}/xax/task/mixins/train.py +0 -0
  50. {xax-0.1.1 → xax-0.1.2}/xax/task/script.py +0 -0
  51. {xax-0.1.1 → xax-0.1.2}/xax/task/task.py +0 -0
  52. {xax-0.1.1 → xax-0.1.2}/xax/utils/__init__.py +0 -0
  53. {xax-0.1.1 → xax-0.1.2}/xax/utils/data/__init__.py +0 -0
  54. {xax-0.1.1 → xax-0.1.2}/xax/utils/data/collate.py +0 -0
  55. {xax-0.1.1 → xax-0.1.2}/xax/utils/debugging.py +0 -0
  56. {xax-0.1.1 → xax-0.1.2}/xax/utils/jax.py +0 -0
  57. {xax-0.1.1 → xax-0.1.2}/xax/utils/jaxpr.py +0 -0
  58. {xax-0.1.1 → xax-0.1.2}/xax/utils/logging.py +0 -0
  59. {xax-0.1.1 → xax-0.1.2}/xax/utils/numpy.py +0 -0
  60. {xax-0.1.1 → xax-0.1.2}/xax/utils/profile.py +0 -0
  61. {xax-0.1.1 → xax-0.1.2}/xax/utils/pytree.py +0 -0
  62. {xax-0.1.1 → xax-0.1.2}/xax/utils/tensorboard.py +0 -0
  63. {xax-0.1.1 → xax-0.1.2}/xax/utils/text.py +0 -0
  64. {xax-0.1.1 → xax-0.1.2}/xax.egg-info/SOURCES.txt +0 -0
  65. {xax-0.1.1 → xax-0.1.2}/xax.egg-info/dependency_links.txt +0 -0
  66. {xax-0.1.1 → xax-0.1.2}/xax.egg-info/requires.txt +0 -0
  67. {xax-0.1.1 → xax-0.1.2}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.1"
15
+ __version__ = "0.1.2"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -66,6 +66,7 @@ __all__ = [
66
66
  "DataloaderConfig",
67
67
  "GPUStatsOptions",
68
68
  "StepContext",
69
+ "ValidStepTimer",
69
70
  "Script",
70
71
  "ScriptConfig",
71
72
  "Config",
@@ -219,6 +220,7 @@ NAME_MAP: dict[str, str] = {
219
220
  "DataloaderConfig": "task.mixins.data_loader",
220
221
  "GPUStatsOptions": "task.mixins.gpu_stats",
221
222
  "StepContext": "task.mixins.step_wrapper",
223
+ "ValidStepTimer": "task.mixins.train",
222
224
  "Script": "task.script",
223
225
  "ScriptConfig": "task.script",
224
226
  "Config": "task.task",
@@ -372,7 +374,7 @@ if IMPORT_ALL or TYPE_CHECKING:
372
374
  from xax.task.mixins.data_loader import DataloaderConfig
373
375
  from xax.task.mixins.gpu_stats import GPUStatsOptions
374
376
  from xax.task.mixins.step_wrapper import StepContext
375
- from xax.task.mixins.train import Batch, Output
377
+ from xax.task.mixins.train import Batch, Output, ValidStepTimer
376
378
  from xax.task.script import Script, ScriptConfig
377
379
  from xax.task.task import Config, Task
378
380
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
@@ -109,8 +109,8 @@ class StateTimer:
109
109
 
110
110
  def step(self, state: State) -> None:
111
111
  cur_time = time.time()
112
- self.step_timer.step(state.num_steps, cur_time)
113
- self.sample_timer.step(state.num_samples, cur_time)
112
+ self.step_timer.step(state.num_steps if state.phase == "train" else state.num_valid_steps, cur_time)
113
+ self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
114
114
  self.iter_timer.step(cur_time)
115
115
 
116
116
  def log_dict(self) -> dict[str, dict[str, int | float]]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes