xax 0.2.3__tar.gz → 0.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.2.3/xax.egg-info → xax-0.2.4}/PKG-INFO +1 -1
  2. {xax-0.2.3 → xax-0.2.4}/xax/__init__.py +1 -1
  3. {xax-0.2.3 → xax-0.2.4}/xax/core/state.py +18 -17
  4. {xax-0.2.3 → xax-0.2.4}/xax/task/loggers/tensorboard.py +1 -1
  5. {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/train.py +8 -7
  6. {xax-0.2.3 → xax-0.2.4/xax.egg-info}/PKG-INFO +1 -1
  7. {xax-0.2.3 → xax-0.2.4}/LICENSE +0 -0
  8. {xax-0.2.3 → xax-0.2.4}/MANIFEST.in +0 -0
  9. {xax-0.2.3 → xax-0.2.4}/README.md +0 -0
  10. {xax-0.2.3 → xax-0.2.4}/pyproject.toml +0 -0
  11. {xax-0.2.3 → xax-0.2.4}/setup.cfg +0 -0
  12. {xax-0.2.3 → xax-0.2.4}/setup.py +0 -0
  13. {xax-0.2.3 → xax-0.2.4}/xax/core/__init__.py +0 -0
  14. {xax-0.2.3 → xax-0.2.4}/xax/core/conf.py +0 -0
  15. {xax-0.2.3 → xax-0.2.4}/xax/nn/__init__.py +0 -0
  16. {xax-0.2.3 → xax-0.2.4}/xax/nn/embeddings.py +0 -0
  17. {xax-0.2.3 → xax-0.2.4}/xax/nn/equinox.py +0 -0
  18. {xax-0.2.3 → xax-0.2.4}/xax/nn/export.py +0 -0
  19. {xax-0.2.3 → xax-0.2.4}/xax/nn/functions.py +0 -0
  20. {xax-0.2.3 → xax-0.2.4}/xax/nn/geom.py +0 -0
  21. {xax-0.2.3 → xax-0.2.4}/xax/nn/losses.py +0 -0
  22. {xax-0.2.3 → xax-0.2.4}/xax/nn/norm.py +0 -0
  23. {xax-0.2.3 → xax-0.2.4}/xax/nn/parallel.py +0 -0
  24. {xax-0.2.3 → xax-0.2.4}/xax/nn/ssm.py +0 -0
  25. {xax-0.2.3 → xax-0.2.4}/xax/py.typed +0 -0
  26. {xax-0.2.3 → xax-0.2.4}/xax/requirements-dev.txt +0 -0
  27. {xax-0.2.3 → xax-0.2.4}/xax/requirements.txt +0 -0
  28. {xax-0.2.3 → xax-0.2.4}/xax/task/__init__.py +0 -0
  29. {xax-0.2.3 → xax-0.2.4}/xax/task/base.py +0 -0
  30. {xax-0.2.3 → xax-0.2.4}/xax/task/launchers/__init__.py +0 -0
  31. {xax-0.2.3 → xax-0.2.4}/xax/task/launchers/base.py +0 -0
  32. {xax-0.2.3 → xax-0.2.4}/xax/task/launchers/cli.py +0 -0
  33. {xax-0.2.3 → xax-0.2.4}/xax/task/launchers/single_process.py +0 -0
  34. {xax-0.2.3 → xax-0.2.4}/xax/task/logger.py +0 -0
  35. {xax-0.2.3 → xax-0.2.4}/xax/task/loggers/__init__.py +0 -0
  36. {xax-0.2.3 → xax-0.2.4}/xax/task/loggers/callback.py +0 -0
  37. {xax-0.2.3 → xax-0.2.4}/xax/task/loggers/json.py +0 -0
  38. {xax-0.2.3 → xax-0.2.4}/xax/task/loggers/state.py +0 -0
  39. {xax-0.2.3 → xax-0.2.4}/xax/task/loggers/stdout.py +0 -0
  40. {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/__init__.py +0 -0
  41. {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/artifacts.py +0 -0
  42. {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/checkpointing.py +0 -0
  43. {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/compile.py +0 -0
  44. {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/cpu_stats.py +0 -0
  45. {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/data_loader.py +0 -0
  46. {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/gpu_stats.py +0 -0
  47. {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/logger.py +0 -0
  48. {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/process.py +0 -0
  49. {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/runnable.py +0 -0
  50. {xax-0.2.3 → xax-0.2.4}/xax/task/mixins/step_wrapper.py +0 -0
  51. {xax-0.2.3 → xax-0.2.4}/xax/task/script.py +0 -0
  52. {xax-0.2.3 → xax-0.2.4}/xax/task/task.py +0 -0
  53. {xax-0.2.3 → xax-0.2.4}/xax/utils/__init__.py +0 -0
  54. {xax-0.2.3 → xax-0.2.4}/xax/utils/data/__init__.py +0 -0
  55. {xax-0.2.3 → xax-0.2.4}/xax/utils/data/collate.py +0 -0
  56. {xax-0.2.3 → xax-0.2.4}/xax/utils/debugging.py +0 -0
  57. {xax-0.2.3 → xax-0.2.4}/xax/utils/experiments.py +0 -0
  58. {xax-0.2.3 → xax-0.2.4}/xax/utils/jax.py +0 -0
  59. {xax-0.2.3 → xax-0.2.4}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.2.3 → xax-0.2.4}/xax/utils/logging.py +0 -0
  61. {xax-0.2.3 → xax-0.2.4}/xax/utils/numpy.py +0 -0
  62. {xax-0.2.3 → xax-0.2.4}/xax/utils/profile.py +0 -0
  63. {xax-0.2.3 → xax-0.2.4}/xax/utils/pytree.py +0 -0
  64. {xax-0.2.3 → xax-0.2.4}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.2.3 → xax-0.2.4}/xax/utils/text.py +0 -0
  66. {xax-0.2.3 → xax-0.2.4}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.2.3 → xax-0.2.4}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.2.3 → xax-0.2.4}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.3 → xax-0.2.4}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.3 → xax-0.2.4}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.3 → xax-0.2.4}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.3 → xax-0.2.4}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.3
3
+ Version: 0.2.4
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.3"
15
+ __version__ = "0.2.4"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -1,6 +1,5 @@
1
1
  """Defines a dataclass for keeping track of the current training state."""
2
2
 
3
- import time
4
3
  from dataclasses import dataclass
5
4
  from typing import Literal, NotRequired, TypedDict, Unpack, cast
6
5
 
@@ -19,6 +18,8 @@ def _phase_to_int(phase: Phase) -> int:
19
18
 
20
19
 
21
20
  def _int_to_phase(i: int) -> Phase:
21
+ if i < 0 or i > 1:
22
+ raise ValueError(f"Invalid phase: {i}")
22
23
  return cast(Phase, ["train", "valid"][i])
23
24
 
24
25
 
@@ -27,8 +28,8 @@ class StateDict(TypedDict, total=False):
27
28
  num_samples: NotRequired[int | Array]
28
29
  num_valid_steps: NotRequired[int | Array]
29
30
  num_valid_samples: NotRequired[int | Array]
30
- start_time_s: NotRequired[float | Array]
31
31
  elapsed_time_s: NotRequired[float | Array]
32
+ valid_elapsed_time_s: NotRequired[float | Array]
32
33
  phase: NotRequired[Phase]
33
34
  _phase: NotRequired[int | Array]
34
35
 
@@ -43,24 +44,24 @@ class State:
43
44
  def num_steps(self) -> Array:
44
45
  return self._int32_arr[0]
45
46
 
46
- @property
47
- def num_samples(self) -> Array:
48
- return self._float32_arr[0]
49
-
50
47
  @property
51
48
  def num_valid_steps(self) -> Array:
52
49
  return self._int32_arr[1]
53
50
 
51
+ @property
52
+ def num_samples(self) -> Array:
53
+ return self._float32_arr[0]
54
+
54
55
  @property
55
56
  def num_valid_samples(self) -> Array:
56
57
  return self._float32_arr[1]
57
58
 
58
59
  @property
59
- def start_time_s(self) -> Array:
60
+ def elapsed_time_s(self) -> Array:
60
61
  return self._float32_arr[2]
61
62
 
62
63
  @property
63
- def elapsed_time_s(self) -> Array:
64
+ def valid_elapsed_time_s(self) -> Array:
64
65
  return self._float32_arr[3]
65
66
 
66
67
  @property
@@ -71,7 +72,7 @@ class State:
71
72
  def init_state(cls) -> "State":
72
73
  return cls(
73
74
  _int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
74
- _float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0], dtype=jnp.float32),
75
+ _float32_arr=jnp.array([0.0, 0.0, 0.0, 0.0], dtype=jnp.float32),
75
76
  )
76
77
 
77
78
  @property
@@ -97,10 +98,10 @@ class State:
97
98
  if "num_valid_samples" in kwargs:
98
99
  float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
99
100
 
100
- if "start_time_s" in kwargs:
101
- float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
102
101
  if "elapsed_time_s" in kwargs:
103
- float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
102
+ float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])
103
+ if "valid_elapsed_time_s" in kwargs:
104
+ float32_arr = float32_arr.at[3].set(kwargs["valid_elapsed_time_s"])
104
105
 
105
106
  return State(
106
107
  _int32_arr=int32_arr,
@@ -110,11 +111,11 @@ class State:
110
111
  def to_dict(self) -> dict[str, int | float | str]:
111
112
  return {
112
113
  "num_steps": int(self.num_steps),
113
- "num_samples": int(self.num_samples),
114
114
  "num_valid_steps": int(self.num_valid_steps),
115
+ "num_samples": int(self.num_samples),
115
116
  "num_valid_samples": int(self.num_valid_samples),
116
- "start_time_s": float(self.start_time_s),
117
117
  "elapsed_time_s": float(self.elapsed_time_s),
118
+ "valid_elapsed_time_s": float(self.valid_elapsed_time_s),
118
119
  "phase": str(self.phase),
119
120
  }
120
121
 
@@ -126,9 +127,7 @@ class State:
126
127
  int32_arr = jnp.array(
127
128
  [
128
129
  d.get("num_steps", 0),
129
- d.get("num_samples", 0),
130
130
  d.get("num_valid_steps", 0),
131
- d.get("num_valid_samples", 0),
132
131
  d.get("_phase", 0),
133
132
  ],
134
133
  dtype=jnp.int32,
@@ -136,8 +135,10 @@ class State:
136
135
 
137
136
  float32_arr = jnp.array(
138
137
  [
139
- d.get("start_time_s", time.time()),
138
+ d.get("num_samples", 0),
139
+ d.get("num_valid_samples", 0),
140
140
  d.get("elapsed_time_s", 0.0),
141
+ d.get("valid_elapsed_time_s", 0.0),
141
142
  ],
142
143
  dtype=jnp.float32,
143
144
  )
@@ -157,7 +157,7 @@ class TensorboardLogger(LoggerImpl):
157
157
  writer = self.get_writer(line.state.phase)
158
158
 
159
159
  global_step = line.state.num_steps.item()
160
- walltime = (line.state.start_time_s + line.state.elapsed_time_s).item()
160
+ walltime = line.state.elapsed_time_s.item()
161
161
 
162
162
  for namespace, scalars in line.scalars.items():
163
163
  for scalar_key, scalar_value in scalars.items():
@@ -720,20 +720,21 @@ class TrainMixin(
720
720
 
721
721
  while not self.is_training_over(state):
722
722
  if self.valid_step_timer.is_valid_step(state):
723
- valid_batch = next(valid_pf)
723
+ with ContextTimer() as timer:
724
+ valid_batch = next(valid_pf)
725
+ output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
726
+ self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
727
+
724
728
  state = state.replace(
725
729
  phase="valid",
726
730
  num_valid_steps=state.num_valid_steps + 1,
727
731
  num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
732
+ valid_elapsed_time_s=state.valid_elapsed_time_s + timer.elapsed_time,
728
733
  )
729
734
 
730
- output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
731
- self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
732
-
733
- state = self.on_step_start(state)
734
- train_batch = next(train_pf)
735
-
736
735
  with ContextTimer() as timer:
736
+ state = self.on_step_start(state)
737
+ train_batch = next(train_pf)
737
738
  model_arr, opt_state, output, metrics = self.train_step(
738
739
  model_arr=model_arr,
739
740
  model_static=model_static,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.3
3
+ Version: 0.2.4
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes