xax 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.11"
15
+ __version__ = "0.2.13"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
xax/core/state.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """Defines a dataclass for keeping track of the current training state."""
2
2
 
3
+ import time
3
4
  from dataclasses import dataclass
4
5
  from typing import Literal, NotRequired, TypedDict, Unpack, cast
5
6
 
@@ -28,6 +29,7 @@ class StateDict(TypedDict, total=False):
28
29
  num_samples: NotRequired[int | Array]
29
30
  num_valid_steps: NotRequired[int | Array]
30
31
  num_valid_samples: NotRequired[int | Array]
32
+ start_time_s: NotRequired[float | Array]
31
33
  elapsed_time_s: NotRequired[float | Array]
32
34
  valid_elapsed_time_s: NotRequired[float | Array]
33
35
  phase: NotRequired[Phase]
@@ -57,13 +59,17 @@ class State:
57
59
  return self._float32_arr[1]
58
60
 
59
61
  @property
60
- def elapsed_time_s(self) -> Array:
62
+ def start_time_s(self) -> Array:
61
63
  return self._float32_arr[2]
62
64
 
63
65
  @property
64
- def valid_elapsed_time_s(self) -> Array:
66
+ def elapsed_time_s(self) -> Array:
65
67
  return self._float32_arr[3]
66
68
 
69
+ @property
70
+ def valid_elapsed_time_s(self) -> Array:
71
+ return self._float32_arr[4]
72
+
67
73
  @property
68
74
  def phase(self) -> Phase:
69
75
  return _int_to_phase(self._int32_arr[2].item())
@@ -72,7 +78,7 @@ class State:
72
78
  def init_state(cls) -> "State":
73
79
  return cls(
74
80
  _int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
75
- _float32_arr=jnp.array([0.0, 0.0, 0.0, 0.0], dtype=jnp.float32),
81
+ _float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0, 0.0], dtype=jnp.float32),
76
82
  )
77
83
 
78
84
  @property
@@ -98,10 +104,12 @@ class State:
98
104
  if "num_valid_samples" in kwargs:
99
105
  float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
100
106
 
107
+ if "start_time_s" in kwargs:
108
+ float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
101
109
  if "elapsed_time_s" in kwargs:
102
- float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])
110
+ float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
103
111
  if "valid_elapsed_time_s" in kwargs:
104
- float32_arr = float32_arr.at[3].set(kwargs["valid_elapsed_time_s"])
112
+ float32_arr = float32_arr.at[4].set(kwargs["valid_elapsed_time_s"])
105
113
 
106
114
  return State(
107
115
  _int32_arr=int32_arr,
@@ -110,12 +118,13 @@ class State:
110
118
 
111
119
  def to_dict(self) -> dict[str, int | float | str]:
112
120
  return {
113
- "num_steps": int(self.num_steps),
114
- "num_valid_steps": int(self.num_valid_steps),
115
- "num_samples": int(self.num_samples),
116
- "num_valid_samples": int(self.num_valid_samples),
117
- "elapsed_time_s": float(self.elapsed_time_s),
118
- "valid_elapsed_time_s": float(self.valid_elapsed_time_s),
121
+ "num_steps": int(self.num_steps.item()),
122
+ "num_valid_steps": int(self.num_valid_steps.item()),
123
+ "num_samples": int(self.num_samples.item()),
124
+ "num_valid_samples": int(self.num_valid_samples.item()),
125
+ "start_time_s": float(self.start_time_s.item()),
126
+ "elapsed_time_s": float(self.elapsed_time_s.item()),
127
+ "valid_elapsed_time_s": float(self.valid_elapsed_time_s.item()),
119
128
  "phase": str(self.phase),
120
129
  }
121
130
 
@@ -137,6 +146,7 @@ class State:
137
146
  [
138
147
  d.get("num_samples", 0),
139
148
  d.get("num_valid_samples", 0),
149
+ d.get("start_time_s", time.time()),
140
150
  d.get("elapsed_time_s", 0.0),
141
151
  d.get("valid_elapsed_time_s", 0.0),
142
152
  ],
@@ -160,7 +160,7 @@ class TensorboardLogger(LoggerImpl):
160
160
  writer = self.get_writer(line.state.phase)
161
161
 
162
162
  global_step = line.state.num_steps.item()
163
- walltime = line.state.elapsed_time_s.item()
163
+ walltime = line.state.start_time_s.item() + line.state.elapsed_time_s.item()
164
164
 
165
165
  for namespace, scalars in line.scalars.items():
166
166
  for scalar_key, scalar_value in scalars.items():
xax/task/mixins/train.py CHANGED
@@ -120,7 +120,7 @@ class ValidStepTimer:
120
120
  self.last_valid_time = state.elapsed_time_s.item()
121
121
  self.last_valid_step = state.num_steps.item()
122
122
 
123
- def is_valid_step(self, state: State) -> bool:
123
+ def __call__(self, state: State) -> bool:
124
124
  if state.num_steps < self.valid_first_n_steps and state.num_valid_steps < self.valid_first_n_steps:
125
125
  return True
126
126
 
@@ -130,15 +130,18 @@ class ValidStepTimer:
130
130
 
131
131
  # Step-based validation.
132
132
  valid_every_n_steps = self.valid_every_n_steps
133
- if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
133
+ if valid_every_n_steps is not None and (
134
+ state.num_steps >= valid_every_n_steps + self.last_valid_step
135
+ or state.num_valid_steps >= valid_every_n_steps + self.last_valid_step
136
+ ):
134
137
  self._reset(state)
135
138
  return True
136
139
 
137
140
  # Time-based validation.
138
141
  valid_every_n_seconds = self.valid_every_n_seconds
139
- if (
140
- valid_every_n_seconds is not None
141
- and state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
142
+ if valid_every_n_seconds is not None and (
143
+ state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
144
+ or state.valid_elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
142
145
  ):
143
146
  self._reset(state)
144
147
  return True
@@ -146,7 +149,10 @@ class ValidStepTimer:
146
149
  # Time-based validation for first validation step.
147
150
  if self.first_valid_step_flag:
148
151
  valid_first_n_seconds = self.valid_first_n_seconds
149
- if valid_first_n_seconds is not None and state.elapsed_time_s.item() >= valid_first_n_seconds:
152
+ if valid_first_n_seconds is not None and (
153
+ state.elapsed_time_s.item() >= valid_first_n_seconds
154
+ or state.valid_elapsed_time_s.item() >= valid_first_n_seconds
155
+ ):
150
156
  self._reset(state)
151
157
  self.first_valid_step_flag = False
152
158
  return True
@@ -722,21 +728,27 @@ class TrainMixin(
722
728
  model_arr, model_static = eqx.partition(model, self.model_partition_fn)
723
729
 
724
730
  while not self.is_training_over(state):
725
- if self.valid_step_timer.is_valid_step(state):
731
+ valid_step = self.valid_step_timer(state)
732
+
733
+ if valid_step:
726
734
  with ContextTimer() as timer:
735
+ state = state.replace(phase="valid")
727
736
  valid_batch = next(valid_pf)
728
737
  output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
729
738
  self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
730
739
 
740
+ state = state.replace(
741
+ num_valid_steps=state.num_valid_steps + 1,
742
+ num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
743
+ )
744
+
731
745
  state = state.replace(
732
- phase="valid",
733
- num_valid_steps=state.num_valid_steps + 1,
734
- num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
735
746
  valid_elapsed_time_s=state.valid_elapsed_time_s + timer.elapsed_time,
736
747
  )
737
748
 
738
749
  with ContextTimer() as timer:
739
750
  state = self.on_step_start(state)
751
+ state = state.replace(phase="train")
740
752
  train_batch = next(train_pf)
741
753
  model_arr, opt_state, output, metrics = self.train_step(
742
754
  model_arr=model_arr,
@@ -748,15 +760,17 @@ class TrainMixin(
748
760
  )
749
761
  self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
750
762
 
763
+ state = state.replace(
764
+ num_steps=state.num_steps + 1,
765
+ num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
766
+ )
767
+
768
+ state = self.on_step_end(state)
769
+
751
770
  state = state.replace(
752
- phase="train",
753
- num_steps=state.num_steps + 1,
754
- num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
755
771
  elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
756
772
  )
757
773
 
758
- state = self.on_step_end(state)
759
-
760
774
  if self.should_checkpoint(state):
761
775
  model = eqx.combine(model_arr, model_static)
762
776
  self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.11
3
+ Version: 0.2.13
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,10 +1,10 @@
1
- xax/__init__.py,sha256=S4p0bL4JmuLyhFkGHlXlugXk-ckbnWtSw1_6r9E0qrI,15510
1
+ xax/__init__.py,sha256=33wIwGeXDFReg2ZnFqUHfSybj5cKyMqnI8ncj8-9yVg,15510
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
5
5
  xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
7
- xax/core/state.py,sha256=yO25lMoLCUTJlHyLzQxlDbsHC_GZ3HkrKAq5huA7AkU,4552
7
+ xax/core/state.py,sha256=KsNMnM_RgsZ2Ntc2pp4Fi6zG4rZb_89-kqmyGxDvyRg,4974
8
8
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
10
  xax/nn/equinox.py,sha256=JZuSApD4bL0UK5W1nrQtucWYvNWUha07J6LTLk_RX-Y,4910
@@ -29,7 +29,7 @@ xax/task/loggers/callback.py,sha256=zQuV1xCvz47Q3UQqP1D5mBhbVzptvmPR_7hX25vqSk0,
29
29
  xax/task/loggers/json.py,sha256=6A5wL7kspsXnpPhI_vu0scgd2Z2-WLhw4gbBFm7eZMM,4377
30
30
  xax/task/loggers/state.py,sha256=0Jy0NYnY4c0qt0LvNlaTaCKOSqk5SCKln5VdyuQGnIc,1407
31
31
  xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,6619
32
- xax/task/loggers/tensorboard.py,sha256=sdsA8GjZG5JQpoAxNDRr_bGvqN8Olgj_almZBb2K5F8,8850
32
+ xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
33
33
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
34
34
  xax/task/mixins/artifacts.py,sha256=Ma7fwsp-SA1w6GcuBSskszj5TB83yxYJm4Ns_EnqkI4,3018
35
35
  xax/task/mixins/checkpointing.py,sha256=zqospBFnTbGt_iriiduVfXazINPbzWpwmIs91KAniMY,10147
@@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
41
41
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
42
42
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
43
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
44
- xax/task/mixins/train.py,sha256=PUVN2OsJpQppIzb4yaULT-C-0ocr1aGbPY-LrNJ2AVY,31322
44
+ xax/task/mixins/train.py,sha256=_QoxSDMW6nmpH82Un2LDsVIBg9KIx8npRwSjY4TEGYA,31830
45
45
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
46
  xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
47
47
  xax/utils/experiments.py,sha256=bj8BftSHT3fFzfiJ0Co0WvqWo0rUS8kQnQYpVvH8FTM,29942
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.2.11.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.2.11.dist-info/METADATA,sha256=MLwHl-vblIYvbpUZ5ylMDjwejKLNOnJK_55JwVNPVH8,1880
63
- xax-0.2.11.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
64
- xax-0.2.11.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.2.11.dist-info/RECORD,,
61
+ xax-0.2.13.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.2.13.dist-info/METADATA,sha256=-foHRw3ph7yxBmMmjO_oqZqwvdEROYTH4Drc9P58ujI,1880
63
+ xax-0.2.13.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
64
+ xax-0.2.13.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.2.13.dist-info/RECORD,,
File without changes