xax 0.1.15__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.15"
15
+ __version__ = "0.1.16"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
xax/task/logger.py CHANGED
@@ -205,6 +205,12 @@ def as_numpy(array: Array) -> np.ndarray:
205
205
  return np.array(array)
206
206
 
207
207
 
208
+ @dataclass(kw_only=True)
209
+ class LogString:
210
+ value: str
211
+ secondary: bool
212
+
213
+
208
214
  @dataclass(kw_only=True)
209
215
  class LogImage:
210
216
  image: PILImage
@@ -223,6 +229,12 @@ class LogVideo:
223
229
  fps: int
224
230
 
225
231
 
232
+ @dataclass(kw_only=True)
233
+ class LogScalar:
234
+ value: Number
235
+ secondary: bool
236
+
237
+
226
238
  @dataclass(kw_only=True)
227
239
  class LogDistribution:
228
240
  mean: Number
@@ -243,10 +255,10 @@ class LogHistogram:
243
255
  @dataclass(kw_only=True)
244
256
  class LogLine:
245
257
  state: State
246
- scalars: dict[str, dict[str, Number]]
258
+ scalars: dict[str, dict[str, LogScalar]]
247
259
  distributions: dict[str, dict[str, LogDistribution]]
248
260
  histograms: dict[str, dict[str, LogHistogram]]
249
- strings: dict[str, dict[str, str]]
261
+ strings: dict[str, dict[str, LogString]]
250
262
  images: dict[str, dict[str, LogImage]]
251
263
  videos: dict[str, dict[str, LogVideo]]
252
264
 
@@ -515,10 +527,10 @@ class Logger:
515
527
  """Defines an intermediate container which holds values to log somewhere else."""
516
528
 
517
529
  def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
518
- self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
530
+ self.scalars: dict[str, dict[str, Callable[[], LogScalar]]] = defaultdict(dict)
519
531
  self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
520
532
  self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
521
- self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
533
+ self.strings: dict[str, dict[str, Callable[[], LogString]]] = defaultdict(dict)
522
534
  self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
523
535
  self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
524
536
  self.default_namespace = default_namespace
@@ -616,13 +628,23 @@ class Logger:
616
628
  def resolve_namespace(self, namespace: str | None = None) -> str:
617
629
  return "_".join([self.default_namespace if namespace is None else namespace] + NAMESPACE_STACK)
618
630
 
619
- def log_scalar(self, key: str, value: Callable[[], Number] | Number, *, namespace: str | None = None) -> None:
631
+ def log_scalar(
632
+ self,
633
+ key: str,
634
+ value: Callable[[], Number] | Number,
635
+ *,
636
+ namespace: str | None = None,
637
+ secondary: bool = False,
638
+ ) -> None:
620
639
  """Logs a scalar value.
621
640
 
622
641
  Args:
623
642
  key: The key being logged
624
643
  value: The scalar value being logged
625
644
  namespace: An optional logging namespace
645
+ secondary: If set, treat this as a secondary value (meaning, it is
646
+ less important than other values, and some downstream loggers
647
+ will not display it)
626
648
  """
627
649
  if not self.active:
628
650
  raise RuntimeError("The logger is not active")
@@ -632,11 +654,11 @@ class Logger:
632
654
  assert value.ndim == 0, f"Scalar must be a 0D array, got shape {value.shape}"
633
655
 
634
656
  @functools.lru_cache(maxsize=None)
635
- def scalar_future() -> Number:
657
+ def scalar_future() -> LogScalar:
636
658
  with ContextTimer() as timer:
637
659
  value_concrete = value() if callable(value) else value
638
660
  logger.debug("Scalar Key: %s, Time: %s", key, timer.elapsed_time)
639
- return value_concrete
661
+ return LogScalar(value=value_concrete, secondary=secondary)
640
662
 
641
663
  self.scalars[namespace][key] = scalar_future
642
664
 
@@ -770,21 +792,31 @@ class Logger:
770
792
 
771
793
  self.histograms[namespace][key] = histogram_future
772
794
 
773
- def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
795
+ def log_string(
796
+ self,
797
+ key: str,
798
+ value: Callable[[], str] | str,
799
+ *,
800
+ namespace: str | None = None,
801
+ secondary: bool = False,
802
+ ) -> None:
774
803
  """Logs a string value.
775
804
 
776
805
  Args:
777
806
  key: The key being logged
778
807
  value: The string value being logged
779
808
  namespace: An optional logging namespace
809
+ secondary: If set, treat this as a secondary value (meaning, it is
810
+ less important than other values, and some downstream loggers
811
+ will not display it)
780
812
  """
781
813
  if not self.active:
782
814
  raise RuntimeError("The logger is not active")
783
815
  namespace = self.resolve_namespace(namespace)
784
816
 
785
817
  @functools.lru_cache(maxsize=None)
786
- def value_future() -> str:
787
- return value() if callable(value) else value
818
+ def value_future() -> LogString:
819
+ return LogString(value=value() if callable(value) else value, secondary=secondary)
788
820
 
789
821
  self.strings[namespace][key] = value_future
790
822
 
xax/task/loggers/json.py CHANGED
@@ -3,11 +3,19 @@
3
3
  import json
4
4
  import sys
5
5
  from dataclasses import asdict
6
- from typing import Any, Literal, TextIO
6
+ from typing import Any, Literal, Mapping, TextIO
7
7
 
8
8
  from jaxtyping import Array
9
9
 
10
- from xax.task.logger import LogError, LoggerImpl, LogLine, LogPing, LogStatus
10
+ from xax.task.logger import (
11
+ LogError,
12
+ LoggerImpl,
13
+ LogLine,
14
+ LogPing,
15
+ LogScalar,
16
+ LogStatus,
17
+ LogString,
18
+ )
11
19
 
12
20
 
13
21
  def get_json_value(value: Any) -> Any: # noqa: ANN401
@@ -61,14 +69,14 @@ class JsonLogger(LoggerImpl):
61
69
  def get_json(self, line: LogLine) -> str:
62
70
  data: dict = {"state": asdict(line.state)}
63
71
 
64
- def add_logs(log: dict[str, dict[str, Any]], data: dict) -> None:
72
+ def add_logs(log: Mapping[str, Mapping[str, LogScalar | LogString]], data: dict) -> None:
65
73
  for namespace, values in log.items():
66
74
  if self.remove_unicode_from_namespaces:
67
75
  namespace = namespace.encode("ascii", errors="ignore").decode("ascii").strip()
68
76
  if namespace not in data:
69
77
  data[namespace] = {}
70
78
  for k, v in values.items():
71
- data[namespace][k] = get_json_value(v)
79
+ data[namespace][k] = get_json_value(v.value)
72
80
 
73
81
  add_logs(line.scalars, data)
74
82
  add_logs(line.strings, data)
@@ -4,11 +4,20 @@ import datetime
4
4
  import logging
5
5
  import sys
6
6
  from collections import deque
7
- from typing import Any, Deque, TextIO
7
+ from typing import Any, Deque, Mapping, TextIO
8
8
 
9
9
  from jaxtyping import Array
10
10
 
11
- from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
11
+ from xax.task.logger import (
12
+ LogError,
13
+ LogErrorSummary,
14
+ LoggerImpl,
15
+ LogLine,
16
+ LogPing,
17
+ LogScalar,
18
+ LogStatus,
19
+ LogString,
20
+ )
12
21
  from xax.utils.text import Color, colored, format_timedelta
13
22
 
14
23
 
@@ -95,20 +104,17 @@ class StdoutLogger(LoggerImpl):
95
104
  def write_log_window(self, line: LogLine) -> None:
96
105
  namespace_to_lines: dict[str, dict[str, str]] = {}
97
106
 
98
- def add_logs(log: dict[str, dict[str, Any]], namespace_to_lines: dict[str, dict[str, str]]) -> None:
107
+ def add_logs(
108
+ log: Mapping[str, Mapping[str, LogScalar | LogString]],
109
+ namespace_to_lines: dict[str, dict[str, str]],
110
+ ) -> None:
99
111
  for namespace, values in log.items():
100
- if not self.log_timers and namespace.startswith("⌛"):
101
- continue
102
- if not self.log_perf and namespace.startswith("🔧"):
103
- continue
104
- if not self.log_optim and namespace.startswith("📉"):
105
- continue
106
- if not self.log_fp and namespace.startswith("⚖️"):
107
- continue
108
- if namespace not in namespace_to_lines:
109
- namespace_to_lines[namespace] = {}
110
112
  for k, v in values.items():
111
- v_str = as_str(v, self.precision)
113
+ if v.secondary:
114
+ continue
115
+ if namespace not in namespace_to_lines:
116
+ namespace_to_lines[namespace] = {}
117
+ v_str = as_str(v.value, self.precision)
112
118
  namespace_to_lines[namespace][k] = v_str
113
119
 
114
120
  add_logs(line.scalars, namespace_to_lines)
@@ -116,9 +122,8 @@ class StdoutLogger(LoggerImpl):
116
122
  if not namespace_to_lines:
117
123
  return
118
124
 
119
- self.write_fp.write("\n")
120
125
  for namespace, lines in sorted(namespace_to_lines.items()):
121
- self.write_fp.write(f"{colored(namespace, 'cyan', bold=True)}\n")
126
+ self.write_fp.write(f"\n{colored(namespace, 'cyan', bold=True)}\n")
122
127
  for k, v in lines.items():
123
128
  self.write_fp.write(f" ↪ {k}: {v}\n")
124
129
 
@@ -158,7 +158,7 @@ class TensorboardLogger(LoggerImpl):
158
158
  for scalar_key, scalar_value in scalars.items():
159
159
  writer.add_scalar(
160
160
  f"{namespace}/{scalar_key}",
161
- as_float(scalar_value),
161
+ as_float(scalar_value.value),
162
162
  global_step=line.state.num_steps,
163
163
  walltime=walltime,
164
164
  )
@@ -192,7 +192,7 @@ class TensorboardLogger(LoggerImpl):
192
192
  for string_key, string_value in strings.items():
193
193
  writer.add_text(
194
194
  f"{namespace}/{string_key}",
195
- string_value,
195
+ string_value.value,
196
196
  global_step=line.state.num_steps,
197
197
  walltime=walltime,
198
198
  )
@@ -248,15 +248,15 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
248
248
  stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
249
249
 
250
250
  if stats is not None:
251
- self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
252
- self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
253
- self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
254
- self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
255
- self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
256
- self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
257
- self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
258
- self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
259
- self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
260
- self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
251
+ self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu", secondary=True)
252
+ self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu", secondary=True)
253
+ self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu", secondary=True)
254
+ self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem", secondary=True)
255
+ self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem", secondary=True)
256
+ self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem", secondary=True)
257
+ self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem", secondary=True)
258
+ self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem", secondary=True)
259
+ self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem", secondary=True)
260
+ self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem", secondary=True)
261
261
 
262
262
  return state
@@ -264,8 +264,8 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
264
264
  for gpu_stat in stats.values():
265
265
  if gpu_stat is None:
266
266
  continue
267
- self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
268
- self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
269
- self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
267
+ self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu", secondary=True)
268
+ self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu", secondary=True)
269
+ self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu", secondary=True)
270
270
 
271
271
  return state
xax/task/mixins/train.py CHANGED
@@ -218,7 +218,12 @@ class TrainMixin(
218
218
  return state.replace(elapsed_time_s=time.time() - state.start_time_s)
219
219
 
220
220
  def log_train_step(
221
- self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
221
+ self,
222
+ model: PyTree,
223
+ batch: Batch,
224
+ output: Output,
225
+ metrics: FrozenDict[str, Array],
226
+ state: State,
222
227
  ) -> None:
223
228
  """Override this function to do logging during the training phase.
224
229
 
@@ -234,7 +239,12 @@ class TrainMixin(
234
239
  """
235
240
 
236
241
  def log_valid_step(
237
- self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
242
+ self,
243
+ model: PyTree,
244
+ batch: Batch,
245
+ output: Output,
246
+ metrics: FrozenDict[str, Array],
247
+ state: State,
238
248
  ) -> None:
239
249
  """Override this function to do logging during the validation phase.
240
250
 
@@ -252,12 +262,20 @@ class TrainMixin(
252
262
  def log_state_timers(self, state: State) -> None:
253
263
  timer = self.state_timers[state.phase]
254
264
  timer.step(state)
255
- for ns, d in timer.log_dict().items():
256
- for k, v in d.items():
257
- self.logger.log_scalar(k, v, namespace=ns)
265
+ for k, v in timer.log_dict().items():
266
+ if isinstance(v, tuple):
267
+ v, secondary = v
268
+ else:
269
+ secondary = False
270
+ self.logger.log_scalar(k, v, namespace="⌛ timers", secondary=secondary)
258
271
 
259
272
  def log_step(
260
- self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
273
+ self,
274
+ model: PyTree,
275
+ batch: Batch,
276
+ output: Output,
277
+ metrics: FrozenDict[str, Array],
278
+ state: State,
261
279
  ) -> None:
262
280
  phase = state.phase
263
281
 
xax/utils/experiments.py CHANGED
@@ -114,28 +114,15 @@ class StateTimer:
114
114
  self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
115
115
  self.iter_timer.step(cur_time)
116
116
 
117
- def log_dict(self) -> dict[str, dict[str, int | float]]:
118
- logs: dict[str, dict[str, int | float]] = {}
119
-
120
- # Logs step statistics.
121
- logs[" steps"] = {
122
- "total": self.step_timer.steps,
123
- "per-second": self.step_timer.steps_per_second,
124
- }
125
-
126
- # Logs sample statistics.
127
- logs["⌛ samples"] = {
128
- "total": self.sample_timer.steps,
129
- "per-second": self.sample_timer.steps_per_second,
117
+ def log_dict(self) -> dict[str, int | float | tuple[int | float, bool]]:
118
+ return {
119
+ "steps": (self.step_timer.steps, True),
120
+ "steps/second": self.step_timer.steps_per_second,
121
+ "samples": (self.sample_timer.steps, True),
122
+ "samples/second": (self.sample_timer.steps_per_second, True),
123
+ "dt": self.iter_timer.iter_seconds,
130
124
  }
131
125
 
132
- # Logs full iteration statistics.
133
- logs["⌛ dt"] = {
134
- "iter": self.iter_timer.iter_seconds,
135
- }
136
-
137
- return logs
138
-
139
126
 
140
127
  class IntervalTicker:
141
128
  def __init__(self, interval: float) -> None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.15
3
+ Version: 0.1.16
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=bV2mTcuiVaVNvwgbDgg7dKDkMeuyA0mqF0muU5KZHeg,14104
1
+ xax/__init__.py,sha256=pSWV5RtPBJynHr7dCqscbnMkETZPUyw8D6MHK4CuS90,14104
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -17,7 +17,7 @@ xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
17
17
  xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
18
18
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
19
  xax/task/base.py,sha256=DqgGIlo5kEWpYix3DdPCEkCgVLUOocjyFr8okaSUq-k,7680
20
- xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
20
+ xax/task/logger.py,sha256=Upx7cCZvaVIs75CHTfIzYmsuaFRsGu0FvziTZuazT4k,37083
21
21
  xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
22
22
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
23
23
  xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -26,25 +26,25 @@ xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,140
26
26
  xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
27
27
  xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
29
- xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
29
+ xax/task/loggers/json.py,sha256=_tKum6jk_gqVzO-4MqSNXbE-Mmn-yJzkRAT-N1y2zes,4139
30
30
  xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
31
- xax/task/loggers/stdout.py,sha256=BBXqr95gNt5KuCN8XyKnTJF8JdwkR4JgLKrkvcaTBVM,6788
32
- xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
31
+ xax/task/loggers/stdout.py,sha256=oeIgPkj4RyJgBuWaJK9ncLa65iBNJCWXhSF8fx3_54c,6564
32
+ xax/task/loggers/tensorboard.py,sha256=HjR-wiCWe0z3nivRzxEZIltzSzka1828bwxWVmMU5Sk,7718
33
33
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
34
34
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
35
35
  xax/task/mixins/checkpointing.py,sha256=nRddgtasagf0oTZE9LE5IN5JY7jy4BD_M0rlqYp4sCM,8554
36
36
  xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
37
- xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
37
+ xax/task/mixins/cpu_stats.py,sha256=vAjEc3HpPnl56m7vshYX0dXAHJrB98DzVdsYSRqQllc,9371
38
38
  xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
39
- xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
39
+ xax/task/mixins/gpu_stats.py,sha256=4HU6teEDlqMitLbSx7fbyL4qBJ0PgGy0Ly_Pzife8yo,8795
40
40
  xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
41
41
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
42
42
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
43
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
44
- xax/task/mixins/train.py,sha256=1hmUx1HIL8HKfwOnupS3Knsw1CiK2YCbIQnUTYyDEms,26157
44
+ xax/task/mixins/train.py,sha256=4Xr8b5LFueFh-f3k8MIJMv3M46_Aaf65YwCbjtSBQ-U,26393
45
45
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
46
  xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
47
- xax/utils/experiments.py,sha256=X6MESZ3z_Z0DLH6NQucuPzibuOc6rZmlf5UZt4in458,29591
47
+ xax/utils/experiments.py,sha256=vm_hWfaty_wEHVdoU2ALiBiGJze7IoDJIfXi6pd_a9I,29360
48
48
  xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
49
49
  xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
50
50
  xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.1.15.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.1.15.dist-info/METADATA,sha256=i5thFSTL1Zx03UpnCj7f71rxSgs0P3L6ZDd6vYEtM7U,1878
63
- xax-0.1.15.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
- xax-0.1.15.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.1.15.dist-info/RECORD,,
61
+ xax-0.1.16.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.1.16.dist-info/METADATA,sha256=gfh7iFi7Wz3fJDf2w1KKs8H0uanhn2HFsR67TvP6uZM,1878
63
+ xax-0.1.16.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
+ xax-0.1.16.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.1.16.dist-info/RECORD,,
File without changes