xax 0.2.3__tar.gz → 0.2.5__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.2.3/xax.egg-info → xax-0.2.5}/PKG-INFO +6 -7
  2. {xax-0.2.3 → xax-0.2.5}/setup.py +3 -7
  3. {xax-0.2.3 → xax-0.2.5}/xax/__init__.py +1 -1
  4. {xax-0.2.3 → xax-0.2.5}/xax/core/state.py +20 -19
  5. {xax-0.2.3 → xax-0.2.5}/xax/nn/export.py +2 -2
  6. {xax-0.2.3 → xax-0.2.5}/xax/task/loggers/tensorboard.py +1 -1
  7. {xax-0.2.3 → xax-0.2.5}/xax/task/mixins/train.py +8 -7
  8. {xax-0.2.3 → xax-0.2.5/xax.egg-info}/PKG-INFO +6 -7
  9. {xax-0.2.3 → xax-0.2.5}/xax.egg-info/requires.txt +3 -5
  10. {xax-0.2.3 → xax-0.2.5}/LICENSE +0 -0
  11. {xax-0.2.3 → xax-0.2.5}/MANIFEST.in +0 -0
  12. {xax-0.2.3 → xax-0.2.5}/README.md +0 -0
  13. {xax-0.2.3 → xax-0.2.5}/pyproject.toml +0 -0
  14. {xax-0.2.3 → xax-0.2.5}/setup.cfg +0 -0
  15. {xax-0.2.3 → xax-0.2.5}/xax/core/__init__.py +0 -0
  16. {xax-0.2.3 → xax-0.2.5}/xax/core/conf.py +0 -0
  17. {xax-0.2.3 → xax-0.2.5}/xax/nn/__init__.py +0 -0
  18. {xax-0.2.3 → xax-0.2.5}/xax/nn/embeddings.py +0 -0
  19. {xax-0.2.3 → xax-0.2.5}/xax/nn/equinox.py +0 -0
  20. {xax-0.2.3 → xax-0.2.5}/xax/nn/functions.py +0 -0
  21. {xax-0.2.3 → xax-0.2.5}/xax/nn/geom.py +0 -0
  22. {xax-0.2.3 → xax-0.2.5}/xax/nn/losses.py +0 -0
  23. {xax-0.2.3 → xax-0.2.5}/xax/nn/norm.py +0 -0
  24. {xax-0.2.3 → xax-0.2.5}/xax/nn/parallel.py +0 -0
  25. {xax-0.2.3 → xax-0.2.5}/xax/nn/ssm.py +0 -0
  26. {xax-0.2.3 → xax-0.2.5}/xax/py.typed +0 -0
  27. {xax-0.2.3 → xax-0.2.5}/xax/requirements-dev.txt +0 -0
  28. {xax-0.2.3 → xax-0.2.5}/xax/requirements.txt +0 -0
  29. {xax-0.2.3 → xax-0.2.5}/xax/task/__init__.py +0 -0
  30. {xax-0.2.3 → xax-0.2.5}/xax/task/base.py +0 -0
  31. {xax-0.2.3 → xax-0.2.5}/xax/task/launchers/__init__.py +0 -0
  32. {xax-0.2.3 → xax-0.2.5}/xax/task/launchers/base.py +0 -0
  33. {xax-0.2.3 → xax-0.2.5}/xax/task/launchers/cli.py +0 -0
  34. {xax-0.2.3 → xax-0.2.5}/xax/task/launchers/single_process.py +0 -0
  35. {xax-0.2.3 → xax-0.2.5}/xax/task/logger.py +0 -0
  36. {xax-0.2.3 → xax-0.2.5}/xax/task/loggers/__init__.py +0 -0
  37. {xax-0.2.3 → xax-0.2.5}/xax/task/loggers/callback.py +0 -0
  38. {xax-0.2.3 → xax-0.2.5}/xax/task/loggers/json.py +0 -0
  39. {xax-0.2.3 → xax-0.2.5}/xax/task/loggers/state.py +0 -0
  40. {xax-0.2.3 → xax-0.2.5}/xax/task/loggers/stdout.py +0 -0
  41. {xax-0.2.3 → xax-0.2.5}/xax/task/mixins/__init__.py +0 -0
  42. {xax-0.2.3 → xax-0.2.5}/xax/task/mixins/artifacts.py +0 -0
  43. {xax-0.2.3 → xax-0.2.5}/xax/task/mixins/checkpointing.py +0 -0
  44. {xax-0.2.3 → xax-0.2.5}/xax/task/mixins/compile.py +0 -0
  45. {xax-0.2.3 → xax-0.2.5}/xax/task/mixins/cpu_stats.py +0 -0
  46. {xax-0.2.3 → xax-0.2.5}/xax/task/mixins/data_loader.py +0 -0
  47. {xax-0.2.3 → xax-0.2.5}/xax/task/mixins/gpu_stats.py +0 -0
  48. {xax-0.2.3 → xax-0.2.5}/xax/task/mixins/logger.py +0 -0
  49. {xax-0.2.3 → xax-0.2.5}/xax/task/mixins/process.py +0 -0
  50. {xax-0.2.3 → xax-0.2.5}/xax/task/mixins/runnable.py +0 -0
  51. {xax-0.2.3 → xax-0.2.5}/xax/task/mixins/step_wrapper.py +0 -0
  52. {xax-0.2.3 → xax-0.2.5}/xax/task/script.py +0 -0
  53. {xax-0.2.3 → xax-0.2.5}/xax/task/task.py +0 -0
  54. {xax-0.2.3 → xax-0.2.5}/xax/utils/__init__.py +0 -0
  55. {xax-0.2.3 → xax-0.2.5}/xax/utils/data/__init__.py +0 -0
  56. {xax-0.2.3 → xax-0.2.5}/xax/utils/data/collate.py +0 -0
  57. {xax-0.2.3 → xax-0.2.5}/xax/utils/debugging.py +0 -0
  58. {xax-0.2.3 → xax-0.2.5}/xax/utils/experiments.py +0 -0
  59. {xax-0.2.3 → xax-0.2.5}/xax/utils/jax.py +0 -0
  60. {xax-0.2.3 → xax-0.2.5}/xax/utils/jaxpr.py +0 -0
  61. {xax-0.2.3 → xax-0.2.5}/xax/utils/logging.py +0 -0
  62. {xax-0.2.3 → xax-0.2.5}/xax/utils/numpy.py +0 -0
  63. {xax-0.2.3 → xax-0.2.5}/xax/utils/profile.py +0 -0
  64. {xax-0.2.3 → xax-0.2.5}/xax/utils/pytree.py +0 -0
  65. {xax-0.2.3 → xax-0.2.5}/xax/utils/tensorboard.py +0 -0
  66. {xax-0.2.3 → xax-0.2.5}/xax/utils/text.py +0 -0
  67. {xax-0.2.3 → xax-0.2.5}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.2.3 → xax-0.2.5}/xax/utils/types/frozen_dict.py +0 -0
  69. {xax-0.2.3 → xax-0.2.5}/xax/utils/types/hashable_array.py +0 -0
  70. {xax-0.2.3 → xax-0.2.5}/xax.egg-info/SOURCES.txt +0 -0
  71. {xax-0.2.3 → xax-0.2.5}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.2.3 → xax-0.2.5}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -31,11 +31,10 @@ Requires-Dist: pytest; extra == "dev"
31
31
  Requires-Dist: types-pillow; extra == "dev"
32
32
  Requires-Dist: types-psutil; extra == "dev"
33
33
  Requires-Dist: types-requests; extra == "dev"
34
- Provides-Extra: export
35
- Requires-Dist: orbax-export; extra == "export"
36
- Requires-Dist: tensorflow; extra == "export"
37
- Provides-Extra: flax
38
- Requires-Dist: flax; extra == "flax"
34
+ Provides-Extra: exportable
35
+ Requires-Dist: flax; extra == "exportable"
36
+ Requires-Dist: orbax-export; extra == "exportable"
37
+ Requires-Dist: tensorflow; extra == "exportable"
39
38
  Provides-Extra: all
40
39
  Requires-Dist: black; extra == "all"
41
40
  Requires-Dist: darglint; extra == "all"
@@ -45,9 +44,9 @@ Requires-Dist: pytest; extra == "all"
45
44
  Requires-Dist: types-pillow; extra == "all"
46
45
  Requires-Dist: types-psutil; extra == "all"
47
46
  Requires-Dist: types-requests; extra == "all"
47
+ Requires-Dist: flax; extra == "all"
48
48
  Requires-Dist: orbax-export; extra == "all"
49
49
  Requires-Dist: tensorflow; extra == "all"
50
- Requires-Dist: flax; extra == "all"
51
50
  Dynamic: author
52
51
  Dynamic: description
53
52
  Dynamic: description-content-type
@@ -15,14 +15,11 @@ with open("xax/requirements-dev.txt", "r", encoding="utf-8") as f:
15
15
  requirements_dev: list[str] = f.read().splitlines()
16
16
 
17
17
  requirements_export: list[str] = [
18
+ "flax",
18
19
  "orbax-export",
19
20
  "tensorflow",
20
21
  ]
21
22
 
22
- requirements_flax: list[str] = [
23
- "flax",
24
- ]
25
-
26
23
  with open("xax/__init__.py", "r", encoding="utf-8") as fh:
27
24
  version_re = re.search(r"^__version__ = \"([^\"]*)\"", fh.read(), re.MULTILINE)
28
25
  assert version_re is not None, "Could not find version in xax/__init__.py"
@@ -42,9 +39,8 @@ setup(
42
39
  tests_require=requirements_dev,
43
40
  extras_require={
44
41
  "dev": requirements_dev,
45
- "export": requirements_export,
46
- "flax": requirements_flax,
47
- "all": requirements_dev + requirements_export + requirements_flax,
42
+ "exportable": requirements_export,
43
+ "all": requirements_dev + requirements_export,
48
44
  },
49
45
  package_data={
50
46
  "xax": [
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.3"
15
+ __version__ = "0.2.5"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -1,6 +1,5 @@
1
1
  """Defines a dataclass for keeping track of the current training state."""
2
2
 
3
- import time
4
3
  from dataclasses import dataclass
5
4
  from typing import Literal, NotRequired, TypedDict, Unpack, cast
6
5
 
@@ -19,6 +18,8 @@ def _phase_to_int(phase: Phase) -> int:
19
18
 
20
19
 
21
20
  def _int_to_phase(i: int) -> Phase:
21
+ if i < 0 or i > 1:
22
+ raise ValueError(f"Invalid phase: {i}")
22
23
  return cast(Phase, ["train", "valid"][i])
23
24
 
24
25
 
@@ -27,8 +28,8 @@ class StateDict(TypedDict, total=False):
27
28
  num_samples: NotRequired[int | Array]
28
29
  num_valid_steps: NotRequired[int | Array]
29
30
  num_valid_samples: NotRequired[int | Array]
30
- start_time_s: NotRequired[float | Array]
31
31
  elapsed_time_s: NotRequired[float | Array]
32
+ valid_elapsed_time_s: NotRequired[float | Array]
32
33
  phase: NotRequired[Phase]
33
34
  _phase: NotRequired[int | Array]
34
35
 
@@ -43,24 +44,24 @@ class State:
43
44
  def num_steps(self) -> Array:
44
45
  return self._int32_arr[0]
45
46
 
46
- @property
47
- def num_samples(self) -> Array:
48
- return self._float32_arr[0]
49
-
50
47
  @property
51
48
  def num_valid_steps(self) -> Array:
52
49
  return self._int32_arr[1]
53
50
 
51
+ @property
52
+ def num_samples(self) -> Array:
53
+ return self._float32_arr[0]
54
+
54
55
  @property
55
56
  def num_valid_samples(self) -> Array:
56
57
  return self._float32_arr[1]
57
58
 
58
59
  @property
59
- def start_time_s(self) -> Array:
60
+ def elapsed_time_s(self) -> Array:
60
61
  return self._float32_arr[2]
61
62
 
62
63
  @property
63
- def elapsed_time_s(self) -> Array:
64
+ def valid_elapsed_time_s(self) -> Array:
64
65
  return self._float32_arr[3]
65
66
 
66
67
  @property
@@ -71,7 +72,7 @@ class State:
71
72
  def init_state(cls) -> "State":
72
73
  return cls(
73
74
  _int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
74
- _float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0], dtype=jnp.float32),
75
+ _float32_arr=jnp.array([0.0, 0.0, 0.0, 0.0], dtype=jnp.float32),
75
76
  )
76
77
 
77
78
  @property
@@ -88,19 +89,19 @@ class State:
88
89
  int32_arr = int32_arr.at[1].set(kwargs["num_valid_steps"])
89
90
 
90
91
  if "phase" in kwargs:
91
- int32_arr = int32_arr.at[3].set(_phase_to_int(kwargs["phase"]))
92
+ int32_arr = int32_arr.at[2].set(_phase_to_int(kwargs["phase"]))
92
93
  if "_phase" in kwargs:
93
- int32_arr = int32_arr.at[3].set(kwargs["_phase"])
94
+ int32_arr = int32_arr.at[2].set(kwargs["_phase"])
94
95
 
95
96
  if "num_samples" in kwargs:
96
97
  float32_arr = float32_arr.at[0].set(kwargs["num_samples"])
97
98
  if "num_valid_samples" in kwargs:
98
99
  float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
99
100
 
100
- if "start_time_s" in kwargs:
101
- float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
102
101
  if "elapsed_time_s" in kwargs:
103
- float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
102
+ float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])
103
+ if "valid_elapsed_time_s" in kwargs:
104
+ float32_arr = float32_arr.at[3].set(kwargs["valid_elapsed_time_s"])
104
105
 
105
106
  return State(
106
107
  _int32_arr=int32_arr,
@@ -110,11 +111,11 @@ class State:
110
111
  def to_dict(self) -> dict[str, int | float | str]:
111
112
  return {
112
113
  "num_steps": int(self.num_steps),
113
- "num_samples": int(self.num_samples),
114
114
  "num_valid_steps": int(self.num_valid_steps),
115
+ "num_samples": int(self.num_samples),
115
116
  "num_valid_samples": int(self.num_valid_samples),
116
- "start_time_s": float(self.start_time_s),
117
117
  "elapsed_time_s": float(self.elapsed_time_s),
118
+ "valid_elapsed_time_s": float(self.valid_elapsed_time_s),
118
119
  "phase": str(self.phase),
119
120
  }
120
121
 
@@ -126,9 +127,7 @@ class State:
126
127
  int32_arr = jnp.array(
127
128
  [
128
129
  d.get("num_steps", 0),
129
- d.get("num_samples", 0),
130
130
  d.get("num_valid_steps", 0),
131
- d.get("num_valid_samples", 0),
132
131
  d.get("_phase", 0),
133
132
  ],
134
133
  dtype=jnp.int32,
@@ -136,8 +135,10 @@ class State:
136
135
 
137
136
  float32_arr = jnp.array(
138
137
  [
139
- d.get("start_time_s", time.time()),
138
+ d.get("num_samples", 0),
139
+ d.get("num_valid_samples", 0),
140
140
  d.get("elapsed_time_s", 0.0),
141
+ d.get("valid_elapsed_time_s", 0.0),
141
142
  ],
142
143
  dtype=jnp.float32,
143
144
  )
@@ -14,8 +14,8 @@ try:
14
14
  from orbax.export import ExportManager, JaxModule, ServingConfig
15
15
  except ImportError as e:
16
16
  raise ImportError(
17
- "In order to export models, please install Xax with export dependencies, "
18
- "using 'xax[export]` to install the required dependencies."
17
+ "In order to export models, please install Xax with exportable dependencies, "
18
+ "using 'xax[exportable]` to install the required dependencies."
19
19
  ) from e
20
20
 
21
21
  logger = logging.getLogger(__name__)
@@ -157,7 +157,7 @@ class TensorboardLogger(LoggerImpl):
157
157
  writer = self.get_writer(line.state.phase)
158
158
 
159
159
  global_step = line.state.num_steps.item()
160
- walltime = (line.state.start_time_s + line.state.elapsed_time_s).item()
160
+ walltime = line.state.elapsed_time_s.item()
161
161
 
162
162
  for namespace, scalars in line.scalars.items():
163
163
  for scalar_key, scalar_value in scalars.items():
@@ -720,20 +720,21 @@ class TrainMixin(
720
720
 
721
721
  while not self.is_training_over(state):
722
722
  if self.valid_step_timer.is_valid_step(state):
723
- valid_batch = next(valid_pf)
723
+ with ContextTimer() as timer:
724
+ valid_batch = next(valid_pf)
725
+ output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
726
+ self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
727
+
724
728
  state = state.replace(
725
729
  phase="valid",
726
730
  num_valid_steps=state.num_valid_steps + 1,
727
731
  num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
732
+ valid_elapsed_time_s=state.valid_elapsed_time_s + timer.elapsed_time,
728
733
  )
729
734
 
730
- output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
731
- self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
732
-
733
- state = self.on_step_start(state)
734
- train_batch = next(train_pf)
735
-
736
735
  with ContextTimer() as timer:
736
+ state = self.on_step_start(state)
737
+ train_batch = next(train_pf)
737
738
  model_arr, opt_state, output, metrics = self.train_step(
738
739
  model_arr=model_arr,
739
740
  model_static=model_static,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -31,11 +31,10 @@ Requires-Dist: pytest; extra == "dev"
31
31
  Requires-Dist: types-pillow; extra == "dev"
32
32
  Requires-Dist: types-psutil; extra == "dev"
33
33
  Requires-Dist: types-requests; extra == "dev"
34
- Provides-Extra: export
35
- Requires-Dist: orbax-export; extra == "export"
36
- Requires-Dist: tensorflow; extra == "export"
37
- Provides-Extra: flax
38
- Requires-Dist: flax; extra == "flax"
34
+ Provides-Extra: exportable
35
+ Requires-Dist: flax; extra == "exportable"
36
+ Requires-Dist: orbax-export; extra == "exportable"
37
+ Requires-Dist: tensorflow; extra == "exportable"
39
38
  Provides-Extra: all
40
39
  Requires-Dist: black; extra == "all"
41
40
  Requires-Dist: darglint; extra == "all"
@@ -45,9 +44,9 @@ Requires-Dist: pytest; extra == "all"
45
44
  Requires-Dist: types-pillow; extra == "all"
46
45
  Requires-Dist: types-psutil; extra == "all"
47
46
  Requires-Dist: types-requests; extra == "all"
47
+ Requires-Dist: flax; extra == "all"
48
48
  Requires-Dist: orbax-export; extra == "all"
49
49
  Requires-Dist: tensorflow; extra == "all"
50
- Requires-Dist: flax; extra == "all"
51
50
  Dynamic: author
52
51
  Dynamic: description
53
52
  Dynamic: description-content-type
@@ -23,9 +23,9 @@ pytest
23
23
  types-pillow
24
24
  types-psutil
25
25
  types-requests
26
+ flax
26
27
  orbax-export
27
28
  tensorflow
28
- flax
29
29
 
30
30
  [dev]
31
31
  black
@@ -37,9 +37,7 @@ types-pillow
37
37
  types-psutil
38
38
  types-requests
39
39
 
40
- [export]
40
+ [exportable]
41
+ flax
41
42
  orbax-export
42
43
  tensorflow
43
-
44
- [flax]
45
- flax
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes