xax 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.5"
15
+ __version__ = "0.2.6"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
xax/task/logger.py CHANGED
@@ -521,7 +521,8 @@ class LoggerImpl(ABC):
521
521
  Returns:
522
522
  If the logger should log the current step.
523
523
  """
524
- return self.tickers[state.phase].tick(state.elapsed_time_s.item())
524
+ elapsed_time = state.elapsed_time_s.item() if state.phase == "train" else state.valid_elapsed_time_s.item()
525
+ return self.tickers[state.phase].tick(elapsed_time)
525
526
 
526
527
 
527
528
  class ToastHandler(logging.Handler):
xax/task/mixins/train.py CHANGED
@@ -115,19 +115,22 @@ class ValidStepTimer:
115
115
  self.last_valid_time: float | None = None
116
116
  self.last_valid_step: int | None = None
117
117
 
118
+ def _reset(self, state: State) -> None:
119
+ self.last_valid_time = state.elapsed_time_s.item()
120
+ self.last_valid_step = state.num_steps.item()
121
+
118
122
  def is_valid_step(self, state: State) -> bool:
119
123
  if state.num_steps < self.valid_first_n_steps:
120
124
  return True
121
125
 
122
126
  if self.last_valid_time is None or self.last_valid_step is None:
123
- self.last_valid_time = state.elapsed_time_s.item()
124
- self.last_valid_step = state.num_steps.item()
127
+ self._reset(state)
125
128
  return False
126
129
 
127
130
  # Step-based validation.
128
131
  valid_every_n_steps = self.valid_every_n_steps
129
132
  if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
130
- self.last_valid_step = state.num_steps.item()
133
+ self._reset(state)
131
134
  return True
132
135
 
133
136
  # Time-based validation.
@@ -136,14 +139,14 @@ class ValidStepTimer:
136
139
  valid_every_n_seconds is not None
137
140
  and state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
138
141
  ):
139
- self.last_valid_time = state.elapsed_time_s.item()
142
+ self._reset(state)
140
143
  return True
141
144
 
142
145
  # Time-based validation for first validation step.
143
146
  if self.first_valid_step_flag:
144
147
  valid_first_n_seconds = self.valid_first_n_seconds
145
148
  if valid_first_n_seconds is not None and state.elapsed_time_s.item() >= valid_first_n_seconds:
146
- self.last_valid_time = state.elapsed_time_s.item()
149
+ self._reset(state)
147
150
  self.first_valid_step_flag = False
148
151
  return True
149
152
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=X_QqDNJir1wdsfRY1CU1F4mdCQMlMZnyqPtY8MM1ODU,14225
1
+ xax/__init__.py,sha256=pvOmycOUli9h53XMTlinjJdY4zzw_fvWoF05UQwoItI,14225
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -17,7 +17,7 @@ xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
17
17
  xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
18
18
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
19
  xax/task/base.py,sha256=OnXi2hiKPGwt6ng1dutnoQSiw7lEiWFlC_vx99_JsbQ,7694
20
- xax/task/logger.py,sha256=y4PGfMqKbfvPk8WCzr9MOsgG2X9E61KgeBVOYp-9kOY,40875
20
+ xax/task/logger.py,sha256=gE67AaPCfU_1FpxY3t0yNRrIVgtp8Sax9UyOqFYMtzM,40976
21
21
  xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
22
22
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
23
23
  xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
41
41
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
42
42
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
43
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
44
- xax/task/mixins/train.py,sha256=XcetJ0MppV_RDhgg1M9_d9heEXo-zeN_FS3MyczeBBU,31219
44
+ xax/task/mixins/train.py,sha256=X3DIMPqpMRAJvFeeDaw5qFkH70R6yJelBAD8KHHwkUc,31196
45
45
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
46
  xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
47
47
  xax/utils/experiments.py,sha256=d2H63ECtVOKySMUMrQRqq4kcuZpoXqo-L931usDVAhE,29903
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.2.5.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.2.5.dist-info/METADATA,sha256=4RBxZF_P0cg-a6QUNS9urvzc4BGGfoedqMrnP0L6Ksk,1879
63
- xax-0.2.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
- xax-0.2.5.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.2.5.dist-info/RECORD,,
61
+ xax-0.2.6.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.2.6.dist-info/METADATA,sha256=g--2dJ3PIYDOWURjNhVmkv5B_VUcI2xY0iBa4oGuuVc,1879
63
+ xax-0.2.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
+ xax-0.2.6.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.2.6.dist-info/RECORD,,
File without changes