xax 0.2.12__tar.gz → 0.2.13__tar.gz

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.2.12/xax.egg-info → xax-0.2.13}/PKG-INFO +1 -1
  2. {xax-0.2.12 → xax-0.2.13}/xax/__init__.py +1 -1
  3. {xax-0.2.12 → xax-0.2.13}/xax/core/state.py +21 -11
  4. {xax-0.2.12 → xax-0.2.13}/xax/task/loggers/tensorboard.py +1 -1
  5. {xax-0.2.12 → xax-0.2.13}/xax/task/mixins/train.py +17 -9
  6. {xax-0.2.12 → xax-0.2.13/xax.egg-info}/PKG-INFO +1 -1
  7. {xax-0.2.12 → xax-0.2.13}/LICENSE +0 -0
  8. {xax-0.2.12 → xax-0.2.13}/MANIFEST.in +0 -0
  9. {xax-0.2.12 → xax-0.2.13}/README.md +0 -0
  10. {xax-0.2.12 → xax-0.2.13}/pyproject.toml +0 -0
  11. {xax-0.2.12 → xax-0.2.13}/setup.cfg +0 -0
  12. {xax-0.2.12 → xax-0.2.13}/setup.py +0 -0
  13. {xax-0.2.12 → xax-0.2.13}/xax/core/__init__.py +0 -0
  14. {xax-0.2.12 → xax-0.2.13}/xax/core/conf.py +0 -0
  15. {xax-0.2.12 → xax-0.2.13}/xax/nn/__init__.py +0 -0
  16. {xax-0.2.12 → xax-0.2.13}/xax/nn/embeddings.py +0 -0
  17. {xax-0.2.12 → xax-0.2.13}/xax/nn/equinox.py +0 -0
  18. {xax-0.2.12 → xax-0.2.13}/xax/nn/export.py +0 -0
  19. {xax-0.2.12 → xax-0.2.13}/xax/nn/functions.py +0 -0
  20. {xax-0.2.12 → xax-0.2.13}/xax/nn/geom.py +0 -0
  21. {xax-0.2.12 → xax-0.2.13}/xax/nn/losses.py +0 -0
  22. {xax-0.2.12 → xax-0.2.13}/xax/nn/norm.py +0 -0
  23. {xax-0.2.12 → xax-0.2.13}/xax/nn/parallel.py +0 -0
  24. {xax-0.2.12 → xax-0.2.13}/xax/nn/ssm.py +0 -0
  25. {xax-0.2.12 → xax-0.2.13}/xax/py.typed +0 -0
  26. {xax-0.2.12 → xax-0.2.13}/xax/requirements-dev.txt +0 -0
  27. {xax-0.2.12 → xax-0.2.13}/xax/requirements.txt +0 -0
  28. {xax-0.2.12 → xax-0.2.13}/xax/task/__init__.py +0 -0
  29. {xax-0.2.12 → xax-0.2.13}/xax/task/base.py +0 -0
  30. {xax-0.2.12 → xax-0.2.13}/xax/task/launchers/__init__.py +0 -0
  31. {xax-0.2.12 → xax-0.2.13}/xax/task/launchers/base.py +0 -0
  32. {xax-0.2.12 → xax-0.2.13}/xax/task/launchers/cli.py +0 -0
  33. {xax-0.2.12 → xax-0.2.13}/xax/task/launchers/single_process.py +0 -0
  34. {xax-0.2.12 → xax-0.2.13}/xax/task/logger.py +0 -0
  35. {xax-0.2.12 → xax-0.2.13}/xax/task/loggers/__init__.py +0 -0
  36. {xax-0.2.12 → xax-0.2.13}/xax/task/loggers/callback.py +0 -0
  37. {xax-0.2.12 → xax-0.2.13}/xax/task/loggers/json.py +0 -0
  38. {xax-0.2.12 → xax-0.2.13}/xax/task/loggers/state.py +0 -0
  39. {xax-0.2.12 → xax-0.2.13}/xax/task/loggers/stdout.py +0 -0
  40. {xax-0.2.12 → xax-0.2.13}/xax/task/mixins/__init__.py +0 -0
  41. {xax-0.2.12 → xax-0.2.13}/xax/task/mixins/artifacts.py +0 -0
  42. {xax-0.2.12 → xax-0.2.13}/xax/task/mixins/checkpointing.py +0 -0
  43. {xax-0.2.12 → xax-0.2.13}/xax/task/mixins/compile.py +0 -0
  44. {xax-0.2.12 → xax-0.2.13}/xax/task/mixins/cpu_stats.py +0 -0
  45. {xax-0.2.12 → xax-0.2.13}/xax/task/mixins/data_loader.py +0 -0
  46. {xax-0.2.12 → xax-0.2.13}/xax/task/mixins/gpu_stats.py +0 -0
  47. {xax-0.2.12 → xax-0.2.13}/xax/task/mixins/logger.py +0 -0
  48. {xax-0.2.12 → xax-0.2.13}/xax/task/mixins/process.py +0 -0
  49. {xax-0.2.12 → xax-0.2.13}/xax/task/mixins/runnable.py +0 -0
  50. {xax-0.2.12 → xax-0.2.13}/xax/task/mixins/step_wrapper.py +0 -0
  51. {xax-0.2.12 → xax-0.2.13}/xax/task/script.py +0 -0
  52. {xax-0.2.12 → xax-0.2.13}/xax/task/task.py +0 -0
  53. {xax-0.2.12 → xax-0.2.13}/xax/utils/__init__.py +0 -0
  54. {xax-0.2.12 → xax-0.2.13}/xax/utils/data/__init__.py +0 -0
  55. {xax-0.2.12 → xax-0.2.13}/xax/utils/data/collate.py +0 -0
  56. {xax-0.2.12 → xax-0.2.13}/xax/utils/debugging.py +0 -0
  57. {xax-0.2.12 → xax-0.2.13}/xax/utils/experiments.py +0 -0
  58. {xax-0.2.12 → xax-0.2.13}/xax/utils/jax.py +0 -0
  59. {xax-0.2.12 → xax-0.2.13}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.2.12 → xax-0.2.13}/xax/utils/logging.py +0 -0
  61. {xax-0.2.12 → xax-0.2.13}/xax/utils/numpy.py +0 -0
  62. {xax-0.2.12 → xax-0.2.13}/xax/utils/profile.py +0 -0
  63. {xax-0.2.12 → xax-0.2.13}/xax/utils/pytree.py +0 -0
  64. {xax-0.2.12 → xax-0.2.13}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.2.12 → xax-0.2.13}/xax/utils/text.py +0 -0
  66. {xax-0.2.12 → xax-0.2.13}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.2.12 → xax-0.2.13}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.2.12 → xax-0.2.13}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.12 → xax-0.2.13}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.12 → xax-0.2.13}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.12 → xax-0.2.13}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.12 → xax-0.2.13}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.12
3
+ Version: 0.2.13
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.12"
15
+ __version__ = "0.2.13"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -1,5 +1,6 @@
1
1
  """Defines a dataclass for keeping track of the current training state."""
2
2
 
3
+ import time
3
4
  from dataclasses import dataclass
4
5
  from typing import Literal, NotRequired, TypedDict, Unpack, cast
5
6
 
@@ -28,6 +29,7 @@ class StateDict(TypedDict, total=False):
28
29
  num_samples: NotRequired[int | Array]
29
30
  num_valid_steps: NotRequired[int | Array]
30
31
  num_valid_samples: NotRequired[int | Array]
32
+ start_time_s: NotRequired[float | Array]
31
33
  elapsed_time_s: NotRequired[float | Array]
32
34
  valid_elapsed_time_s: NotRequired[float | Array]
33
35
  phase: NotRequired[Phase]
@@ -57,13 +59,17 @@ class State:
57
59
  return self._float32_arr[1]
58
60
 
59
61
  @property
60
- def elapsed_time_s(self) -> Array:
62
+ def start_time_s(self) -> Array:
61
63
  return self._float32_arr[2]
62
64
 
63
65
  @property
64
- def valid_elapsed_time_s(self) -> Array:
66
+ def elapsed_time_s(self) -> Array:
65
67
  return self._float32_arr[3]
66
68
 
69
+ @property
70
+ def valid_elapsed_time_s(self) -> Array:
71
+ return self._float32_arr[4]
72
+
67
73
  @property
68
74
  def phase(self) -> Phase:
69
75
  return _int_to_phase(self._int32_arr[2].item())
@@ -72,7 +78,7 @@ class State:
72
78
  def init_state(cls) -> "State":
73
79
  return cls(
74
80
  _int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
75
- _float32_arr=jnp.array([0.0, 0.0, 0.0, 0.0], dtype=jnp.float32),
81
+ _float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0, 0.0], dtype=jnp.float32),
76
82
  )
77
83
 
78
84
  @property
@@ -98,10 +104,12 @@ class State:
98
104
  if "num_valid_samples" in kwargs:
99
105
  float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
100
106
 
107
+ if "start_time_s" in kwargs:
108
+ float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
101
109
  if "elapsed_time_s" in kwargs:
102
- float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])
110
+ float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
103
111
  if "valid_elapsed_time_s" in kwargs:
104
- float32_arr = float32_arr.at[3].set(kwargs["valid_elapsed_time_s"])
112
+ float32_arr = float32_arr.at[4].set(kwargs["valid_elapsed_time_s"])
105
113
 
106
114
  return State(
107
115
  _int32_arr=int32_arr,
@@ -110,12 +118,13 @@ class State:
110
118
 
111
119
  def to_dict(self) -> dict[str, int | float | str]:
112
120
  return {
113
- "num_steps": int(self.num_steps),
114
- "num_valid_steps": int(self.num_valid_steps),
115
- "num_samples": int(self.num_samples),
116
- "num_valid_samples": int(self.num_valid_samples),
117
- "elapsed_time_s": float(self.elapsed_time_s),
118
- "valid_elapsed_time_s": float(self.valid_elapsed_time_s),
121
+ "num_steps": int(self.num_steps.item()),
122
+ "num_valid_steps": int(self.num_valid_steps.item()),
123
+ "num_samples": int(self.num_samples.item()),
124
+ "num_valid_samples": int(self.num_valid_samples.item()),
125
+ "start_time_s": float(self.start_time_s.item()),
126
+ "elapsed_time_s": float(self.elapsed_time_s.item()),
127
+ "valid_elapsed_time_s": float(self.valid_elapsed_time_s.item()),
119
128
  "phase": str(self.phase),
120
129
  }
121
130
 
@@ -137,6 +146,7 @@ class State:
137
146
  [
138
147
  d.get("num_samples", 0),
139
148
  d.get("num_valid_samples", 0),
149
+ d.get("start_time_s", time.time()),
140
150
  d.get("elapsed_time_s", 0.0),
141
151
  d.get("valid_elapsed_time_s", 0.0),
142
152
  ],
@@ -160,7 +160,7 @@ class TensorboardLogger(LoggerImpl):
160
160
  writer = self.get_writer(line.state.phase)
161
161
 
162
162
  global_step = line.state.num_steps.item()
163
- walltime = line.state.elapsed_time_s.item()
163
+ walltime = line.state.start_time_s.item() + line.state.elapsed_time_s.item()
164
164
 
165
165
  for namespace, scalars in line.scalars.items():
166
166
  for scalar_key, scalar_value in scalars.items():
@@ -728,21 +728,27 @@ class TrainMixin(
728
728
  model_arr, model_static = eqx.partition(model, self.model_partition_fn)
729
729
 
730
730
  while not self.is_training_over(state):
731
- if self.valid_step_timer(state):
731
+ valid_step = self.valid_step_timer(state)
732
+
733
+ if valid_step:
732
734
  with ContextTimer() as timer:
735
+ state = state.replace(phase="valid")
733
736
  valid_batch = next(valid_pf)
734
737
  output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
735
738
  self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
736
739
 
740
+ state = state.replace(
741
+ num_valid_steps=state.num_valid_steps + 1,
742
+ num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
743
+ )
744
+
737
745
  state = state.replace(
738
- phase="valid",
739
- num_valid_steps=state.num_valid_steps + 1,
740
- num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
741
746
  valid_elapsed_time_s=state.valid_elapsed_time_s + timer.elapsed_time,
742
747
  )
743
748
 
744
749
  with ContextTimer() as timer:
745
750
  state = self.on_step_start(state)
751
+ state = state.replace(phase="train")
746
752
  train_batch = next(train_pf)
747
753
  model_arr, opt_state, output, metrics = self.train_step(
748
754
  model_arr=model_arr,
@@ -754,15 +760,17 @@ class TrainMixin(
754
760
  )
755
761
  self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
756
762
 
763
+ state = state.replace(
764
+ num_steps=state.num_steps + 1,
765
+ num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
766
+ )
767
+
768
+ state = self.on_step_end(state)
769
+
757
770
  state = state.replace(
758
- phase="train",
759
- num_steps=state.num_steps + 1,
760
- num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
761
771
  elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
762
772
  )
763
773
 
764
- state = self.on_step_end(state)
765
-
766
774
  if self.should_checkpoint(state):
767
775
  model = eqx.combine(model_arr, model_static)
768
776
  self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.12
3
+ Version: 0.2.13
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes