xax 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.3"
15
+ __version__ = "0.2.5"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
xax/core/state.py CHANGED
@@ -1,6 +1,5 @@
1
1
  """Defines a dataclass for keeping track of the current training state."""
2
2
 
3
- import time
4
3
  from dataclasses import dataclass
5
4
  from typing import Literal, NotRequired, TypedDict, Unpack, cast
6
5
 
@@ -19,6 +18,8 @@ def _phase_to_int(phase: Phase) -> int:
19
18
 
20
19
 
21
20
  def _int_to_phase(i: int) -> Phase:
21
+ if i < 0 or i > 1:
22
+ raise ValueError(f"Invalid phase: {i}")
22
23
  return cast(Phase, ["train", "valid"][i])
23
24
 
24
25
 
@@ -27,8 +28,8 @@ class StateDict(TypedDict, total=False):
27
28
  num_samples: NotRequired[int | Array]
28
29
  num_valid_steps: NotRequired[int | Array]
29
30
  num_valid_samples: NotRequired[int | Array]
30
- start_time_s: NotRequired[float | Array]
31
31
  elapsed_time_s: NotRequired[float | Array]
32
+ valid_elapsed_time_s: NotRequired[float | Array]
32
33
  phase: NotRequired[Phase]
33
34
  _phase: NotRequired[int | Array]
34
35
 
@@ -43,24 +44,24 @@ class State:
43
44
  def num_steps(self) -> Array:
44
45
  return self._int32_arr[0]
45
46
 
46
- @property
47
- def num_samples(self) -> Array:
48
- return self._float32_arr[0]
49
-
50
47
  @property
51
48
  def num_valid_steps(self) -> Array:
52
49
  return self._int32_arr[1]
53
50
 
51
+ @property
52
+ def num_samples(self) -> Array:
53
+ return self._float32_arr[0]
54
+
54
55
  @property
55
56
  def num_valid_samples(self) -> Array:
56
57
  return self._float32_arr[1]
57
58
 
58
59
  @property
59
- def start_time_s(self) -> Array:
60
+ def elapsed_time_s(self) -> Array:
60
61
  return self._float32_arr[2]
61
62
 
62
63
  @property
63
- def elapsed_time_s(self) -> Array:
64
+ def valid_elapsed_time_s(self) -> Array:
64
65
  return self._float32_arr[3]
65
66
 
66
67
  @property
@@ -71,7 +72,7 @@ class State:
71
72
  def init_state(cls) -> "State":
72
73
  return cls(
73
74
  _int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
74
- _float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0], dtype=jnp.float32),
75
+ _float32_arr=jnp.array([0.0, 0.0, 0.0, 0.0], dtype=jnp.float32),
75
76
  )
76
77
 
77
78
  @property
@@ -88,19 +89,19 @@ class State:
88
89
  int32_arr = int32_arr.at[1].set(kwargs["num_valid_steps"])
89
90
 
90
91
  if "phase" in kwargs:
91
- int32_arr = int32_arr.at[3].set(_phase_to_int(kwargs["phase"]))
92
+ int32_arr = int32_arr.at[2].set(_phase_to_int(kwargs["phase"]))
92
93
  if "_phase" in kwargs:
93
- int32_arr = int32_arr.at[3].set(kwargs["_phase"])
94
+ int32_arr = int32_arr.at[2].set(kwargs["_phase"])
94
95
 
95
96
  if "num_samples" in kwargs:
96
97
  float32_arr = float32_arr.at[0].set(kwargs["num_samples"])
97
98
  if "num_valid_samples" in kwargs:
98
99
  float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
99
100
 
100
- if "start_time_s" in kwargs:
101
- float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
102
101
  if "elapsed_time_s" in kwargs:
103
- float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
102
+ float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])
103
+ if "valid_elapsed_time_s" in kwargs:
104
+ float32_arr = float32_arr.at[3].set(kwargs["valid_elapsed_time_s"])
104
105
 
105
106
  return State(
106
107
  _int32_arr=int32_arr,
@@ -110,11 +111,11 @@ class State:
110
111
  def to_dict(self) -> dict[str, int | float | str]:
111
112
  return {
112
113
  "num_steps": int(self.num_steps),
113
- "num_samples": int(self.num_samples),
114
114
  "num_valid_steps": int(self.num_valid_steps),
115
+ "num_samples": int(self.num_samples),
115
116
  "num_valid_samples": int(self.num_valid_samples),
116
- "start_time_s": float(self.start_time_s),
117
117
  "elapsed_time_s": float(self.elapsed_time_s),
118
+ "valid_elapsed_time_s": float(self.valid_elapsed_time_s),
118
119
  "phase": str(self.phase),
119
120
  }
120
121
 
@@ -126,9 +127,7 @@ class State:
126
127
  int32_arr = jnp.array(
127
128
  [
128
129
  d.get("num_steps", 0),
129
- d.get("num_samples", 0),
130
130
  d.get("num_valid_steps", 0),
131
- d.get("num_valid_samples", 0),
132
131
  d.get("_phase", 0),
133
132
  ],
134
133
  dtype=jnp.int32,
@@ -136,8 +135,10 @@ class State:
136
135
 
137
136
  float32_arr = jnp.array(
138
137
  [
139
- d.get("start_time_s", time.time()),
138
+ d.get("num_samples", 0),
139
+ d.get("num_valid_samples", 0),
140
140
  d.get("elapsed_time_s", 0.0),
141
+ d.get("valid_elapsed_time_s", 0.0),
141
142
  ],
142
143
  dtype=jnp.float32,
143
144
  )
xax/nn/export.py CHANGED
@@ -14,8 +14,8 @@ try:
14
14
  from orbax.export import ExportManager, JaxModule, ServingConfig
15
15
  except ImportError as e:
16
16
  raise ImportError(
17
- "In order to export models, please install Xax with export dependencies, "
18
- "using 'xax[export]` to install the required dependencies."
17
+ "In order to export models, please install Xax with exportable dependencies, "
18
+ "using 'xax[exportable]` to install the required dependencies."
19
19
  ) from e
20
20
 
21
21
  logger = logging.getLogger(__name__)
@@ -157,7 +157,7 @@ class TensorboardLogger(LoggerImpl):
157
157
  writer = self.get_writer(line.state.phase)
158
158
 
159
159
  global_step = line.state.num_steps.item()
160
- walltime = (line.state.start_time_s + line.state.elapsed_time_s).item()
160
+ walltime = line.state.elapsed_time_s.item()
161
161
 
162
162
  for namespace, scalars in line.scalars.items():
163
163
  for scalar_key, scalar_value in scalars.items():
xax/task/mixins/train.py CHANGED
@@ -720,20 +720,21 @@ class TrainMixin(
720
720
 
721
721
  while not self.is_training_over(state):
722
722
  if self.valid_step_timer.is_valid_step(state):
723
- valid_batch = next(valid_pf)
723
+ with ContextTimer() as timer:
724
+ valid_batch = next(valid_pf)
725
+ output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
726
+ self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
727
+
724
728
  state = state.replace(
725
729
  phase="valid",
726
730
  num_valid_steps=state.num_valid_steps + 1,
727
731
  num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
732
+ valid_elapsed_time_s=state.valid_elapsed_time_s + timer.elapsed_time,
728
733
  )
729
734
 
730
- output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
731
- self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
732
-
733
- state = self.on_step_start(state)
734
- train_batch = next(train_pf)
735
-
736
735
  with ContextTimer() as timer:
736
+ state = self.on_step_start(state)
737
+ train_batch = next(train_pf)
737
738
  model_arr, opt_state, output, metrics = self.train_step(
738
739
  model_arr=model_arr,
739
740
  model_static=model_static,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -31,11 +31,10 @@ Requires-Dist: pytest; extra == "dev"
31
31
  Requires-Dist: types-pillow; extra == "dev"
32
32
  Requires-Dist: types-psutil; extra == "dev"
33
33
  Requires-Dist: types-requests; extra == "dev"
34
- Provides-Extra: export
35
- Requires-Dist: orbax-export; extra == "export"
36
- Requires-Dist: tensorflow; extra == "export"
37
- Provides-Extra: flax
38
- Requires-Dist: flax; extra == "flax"
34
+ Provides-Extra: exportable
35
+ Requires-Dist: flax; extra == "exportable"
36
+ Requires-Dist: orbax-export; extra == "exportable"
37
+ Requires-Dist: tensorflow; extra == "exportable"
39
38
  Provides-Extra: all
40
39
  Requires-Dist: black; extra == "all"
41
40
  Requires-Dist: darglint; extra == "all"
@@ -45,9 +44,9 @@ Requires-Dist: pytest; extra == "all"
45
44
  Requires-Dist: types-pillow; extra == "all"
46
45
  Requires-Dist: types-psutil; extra == "all"
47
46
  Requires-Dist: types-requests; extra == "all"
47
+ Requires-Dist: flax; extra == "all"
48
48
  Requires-Dist: orbax-export; extra == "all"
49
49
  Requires-Dist: tensorflow; extra == "all"
50
- Requires-Dist: flax; extra == "all"
51
50
  Dynamic: author
52
51
  Dynamic: description
53
52
  Dynamic: description-content-type
@@ -1,14 +1,14 @@
1
- xax/__init__.py,sha256=P4q2IGkfpHaN3ZlGFiW0bzWm1spLSUyl0GEPvH8oITg,14225
1
+ xax/__init__.py,sha256=X_QqDNJir1wdsfRY1CU1F4mdCQMlMZnyqPtY8MM1ODU,14225
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
5
5
  xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
7
- xax/core/state.py,sha256=lA7A5HCm2Nwk4J0kJGlTIhqHYFvbuwHfdNzOhmjEW08,4453
7
+ xax/core/state.py,sha256=yO25lMoLCUTJlHyLzQxlDbsHC_GZ3HkrKAq5huA7AkU,4552
8
8
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
10
  xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
11
- xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
11
+ xax/nn/export.py,sha256=pRfM2B4hB2EvljysC6AjtgB_7Cn7JtaP3dhYU2stZtY,5545
12
12
  xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
13
13
  xax/nn/geom.py,sha256=rImNlkHWeoNcY7f84nknizJ6uzsrMhbAtKeb2xAWxNY,6215
14
14
  xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
@@ -29,7 +29,7 @@ xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,
29
29
  xax/task/loggers/json.py,sha256=_tKum6jk_gqVzO-4MqSNXbE-Mmn-yJzkRAT-N1y2zes,4139
30
30
  xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
31
31
  xax/task/loggers/stdout.py,sha256=ERLFrYe61hSSztzyxBRseobHQR72YFYjEd2i_hOeJ20,6595
32
- xax/task/loggers/tensorboard.py,sha256=3ohI6STgSCbU8oyeiH_f3QyLVF_zO_6dwjn0ns59rUU,8334
32
+ xax/task/loggers/tensorboard.py,sha256=KFlsK0zD2ubDqAXYL4Ds7NQ9F-Ke-PHwfhLOYsGcbw4,8306
33
33
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
34
34
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
35
35
  xax/task/mixins/checkpointing.py,sha256=8Hi-2G0EA5OFRjgiOutlk7HgkD5b-0GHazOAYxnGytM,11409
@@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
41
41
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
42
42
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
43
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
44
- xax/task/mixins/train.py,sha256=lMHCnxsbZJbwK3esL5S3cJ0Jf5Qx19Y4pm3A7NY-TIE,31064
44
+ xax/task/mixins/train.py,sha256=XcetJ0MppV_RDhgg1M9_d9heEXo-zeN_FS3MyczeBBU,31219
45
45
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
46
  xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
47
47
  xax/utils/experiments.py,sha256=d2H63ECtVOKySMUMrQRqq4kcuZpoXqo-L931usDVAhE,29903
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.2.3.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.2.3.dist-info/METADATA,sha256=ukAnG444wnzRpXgmHSrs7RKJ-UQvOdl6ZE2ZrN0w4Yg,1882
63
- xax-0.2.3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
- xax-0.2.3.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.2.3.dist-info/RECORD,,
61
+ xax-0.2.5.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.2.5.dist-info/METADATA,sha256=4RBxZF_P0cg-a6QUNS9urvzc4BGGfoedqMrnP0L6Ksk,1879
63
+ xax-0.2.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
+ xax-0.2.5.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.2.5.dist-info/RECORD,,
File without changes