xax 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.14"
15
+ __version__ = "0.1.16"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -40,6 +40,7 @@ __all__ = [
40
40
  "load_eqx_mlp",
41
41
  "make_eqx_mlp",
42
42
  "save_eqx",
43
+ "cubic_bezier_interpolation",
43
44
  "euler_to_quat",
44
45
  "get_projected_gravity_vector_from_quat",
45
46
  "quat_to_euler",
@@ -201,6 +202,7 @@ NAME_MAP: dict[str, str] = {
201
202
  "load_eqx_mlp": "nn.equinox",
202
203
  "make_eqx_mlp": "nn.equinox",
203
204
  "save_eqx": "nn.equinox",
205
+ "cubic_bezier_interpolation": "nn.geom",
204
206
  "euler_to_quat": "nn.geom",
205
207
  "get_projected_gravity_vector_from_quat": "nn.geom",
206
208
  "quat_to_euler": "nn.geom",
@@ -363,6 +365,7 @@ if IMPORT_ALL or TYPE_CHECKING:
363
365
  save_eqx,
364
366
  )
365
367
  from xax.nn.geom import (
368
+ cubic_bezier_interpolation,
366
369
  euler_to_quat,
367
370
  get_projected_gravity_vector_from_quat,
368
371
  quat_to_euler,
xax/nn/geom.py CHANGED
@@ -1,10 +1,10 @@
1
1
  """Defines geometry functions."""
2
2
 
3
- import jax
4
3
  from jax import numpy as jnp
4
+ from jaxtyping import Array
5
5
 
6
6
 
7
- def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
7
+ def quat_to_euler(quat_4: Array, eps: float = 1e-6) -> Array:
8
8
  """Normalizes and converts a quaternion (w, x, y, z) to roll, pitch, yaw.
9
9
 
10
10
  Args:
@@ -40,7 +40,7 @@ def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
40
40
  return jnp.concatenate([roll, pitch, yaw], axis=-1)
41
41
 
42
42
 
43
- def euler_to_quat(euler_3: jax.Array) -> jax.Array:
43
+ def euler_to_quat(euler_3: Array) -> Array:
44
44
  """Converts roll, pitch, yaw angles to a quaternion (w, x, y, z).
45
45
 
46
46
  Args:
@@ -75,7 +75,7 @@ def euler_to_quat(euler_3: jax.Array) -> jax.Array:
75
75
  return quat
76
76
 
77
77
 
78
- def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -> jax.Array:
78
+ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Array:
79
79
  """Calculates the gravity vector projected onto the local frame given a quaternion orientation.
80
80
 
81
81
  Args:
@@ -101,7 +101,7 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -
101
101
  return jnp.concatenate([gx, gy, -gz], axis=-1)
102
102
 
103
103
 
104
- def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) -> jax.Array:
104
+ def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Array:
105
105
  """Rotates a vector by a quaternion.
106
106
 
107
107
  Args:
@@ -156,3 +156,24 @@ def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6)
156
156
  )
157
157
 
158
158
  return jnp.concatenate([xx, yy, zz], axis=-1)
159
+
160
+
161
+ def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
162
+ """Cubic bezier interpolation.
163
+
164
+ This is a cubic bezier curve that starts at y_start and ends at y_end,
165
+ and is controlled by the parameter x. The curve is defined by the following formula:
166
+
167
+ y(x) = y_start + (y_end - y_start) * (x**3 + 3 * (x**2 * (1 - x)))
168
+
169
+ Args:
170
+ y_start: The start value, shape (*).
171
+ y_end: The end value, shape (*).
172
+ x: The interpolation parameter, shape (*).
173
+
174
+ Returns:
175
+ The interpolated value, shape (*).
176
+ """
177
+ y_diff = y_end - y_start
178
+ bezier = x**3 + 3 * (x**2 * (1 - x))
179
+ return y_start + y_diff * bezier
xax/task/logger.py CHANGED
@@ -205,6 +205,12 @@ def as_numpy(array: Array) -> np.ndarray:
205
205
  return np.array(array)
206
206
 
207
207
 
208
+ @dataclass(kw_only=True)
209
+ class LogString:
210
+ value: str
211
+ secondary: bool
212
+
213
+
208
214
  @dataclass(kw_only=True)
209
215
  class LogImage:
210
216
  image: PILImage
@@ -223,6 +229,12 @@ class LogVideo:
223
229
  fps: int
224
230
 
225
231
 
232
+ @dataclass(kw_only=True)
233
+ class LogScalar:
234
+ value: Number
235
+ secondary: bool
236
+
237
+
226
238
  @dataclass(kw_only=True)
227
239
  class LogDistribution:
228
240
  mean: Number
@@ -243,10 +255,10 @@ class LogHistogram:
243
255
  @dataclass(kw_only=True)
244
256
  class LogLine:
245
257
  state: State
246
- scalars: dict[str, dict[str, Number]]
258
+ scalars: dict[str, dict[str, LogScalar]]
247
259
  distributions: dict[str, dict[str, LogDistribution]]
248
260
  histograms: dict[str, dict[str, LogHistogram]]
249
- strings: dict[str, dict[str, str]]
261
+ strings: dict[str, dict[str, LogString]]
250
262
  images: dict[str, dict[str, LogImage]]
251
263
  videos: dict[str, dict[str, LogVideo]]
252
264
 
@@ -515,10 +527,10 @@ class Logger:
515
527
  """Defines an intermediate container which holds values to log somewhere else."""
516
528
 
517
529
  def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
518
- self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
530
+ self.scalars: dict[str, dict[str, Callable[[], LogScalar]]] = defaultdict(dict)
519
531
  self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
520
532
  self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
521
- self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
533
+ self.strings: dict[str, dict[str, Callable[[], LogString]]] = defaultdict(dict)
522
534
  self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
523
535
  self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
524
536
  self.default_namespace = default_namespace
@@ -616,13 +628,23 @@ class Logger:
616
628
  def resolve_namespace(self, namespace: str | None = None) -> str:
617
629
  return "_".join([self.default_namespace if namespace is None else namespace] + NAMESPACE_STACK)
618
630
 
619
- def log_scalar(self, key: str, value: Callable[[], Number] | Number, *, namespace: str | None = None) -> None:
631
+ def log_scalar(
632
+ self,
633
+ key: str,
634
+ value: Callable[[], Number] | Number,
635
+ *,
636
+ namespace: str | None = None,
637
+ secondary: bool = False,
638
+ ) -> None:
620
639
  """Logs a scalar value.
621
640
 
622
641
  Args:
623
642
  key: The key being logged
624
643
  value: The scalar value being logged
625
644
  namespace: An optional logging namespace
645
+ secondary: If set, treat this as a secondary value (meaning, it is
646
+ less important than other values, and some downstream loggers
647
+ will not display it)
626
648
  """
627
649
  if not self.active:
628
650
  raise RuntimeError("The logger is not active")
@@ -632,11 +654,11 @@ class Logger:
632
654
  assert value.ndim == 0, f"Scalar must be a 0D array, got shape {value.shape}"
633
655
 
634
656
  @functools.lru_cache(maxsize=None)
635
- def scalar_future() -> Number:
657
+ def scalar_future() -> LogScalar:
636
658
  with ContextTimer() as timer:
637
659
  value_concrete = value() if callable(value) else value
638
660
  logger.debug("Scalar Key: %s, Time: %s", key, timer.elapsed_time)
639
- return value_concrete
661
+ return LogScalar(value=value_concrete, secondary=secondary)
640
662
 
641
663
  self.scalars[namespace][key] = scalar_future
642
664
 
@@ -770,21 +792,31 @@ class Logger:
770
792
 
771
793
  self.histograms[namespace][key] = histogram_future
772
794
 
773
- def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
795
+ def log_string(
796
+ self,
797
+ key: str,
798
+ value: Callable[[], str] | str,
799
+ *,
800
+ namespace: str | None = None,
801
+ secondary: bool = False,
802
+ ) -> None:
774
803
  """Logs a string value.
775
804
 
776
805
  Args:
777
806
  key: The key being logged
778
807
  value: The string value being logged
779
808
  namespace: An optional logging namespace
809
+ secondary: If set, treat this as a secondary value (meaning, it is
810
+ less important than other values, and some downstream loggers
811
+ will not display it)
780
812
  """
781
813
  if not self.active:
782
814
  raise RuntimeError("The logger is not active")
783
815
  namespace = self.resolve_namespace(namespace)
784
816
 
785
817
  @functools.lru_cache(maxsize=None)
786
- def value_future() -> str:
787
- return value() if callable(value) else value
818
+ def value_future() -> LogString:
819
+ return LogString(value=value() if callable(value) else value, secondary=secondary)
788
820
 
789
821
  self.strings[namespace][key] = value_future
790
822
 
xax/task/loggers/json.py CHANGED
@@ -3,11 +3,19 @@
3
3
  import json
4
4
  import sys
5
5
  from dataclasses import asdict
6
- from typing import Any, Literal, TextIO
6
+ from typing import Any, Literal, Mapping, TextIO
7
7
 
8
8
  from jaxtyping import Array
9
9
 
10
- from xax.task.logger import LogError, LoggerImpl, LogLine, LogPing, LogStatus
10
+ from xax.task.logger import (
11
+ LogError,
12
+ LoggerImpl,
13
+ LogLine,
14
+ LogPing,
15
+ LogScalar,
16
+ LogStatus,
17
+ LogString,
18
+ )
11
19
 
12
20
 
13
21
  def get_json_value(value: Any) -> Any: # noqa: ANN401
@@ -61,14 +69,14 @@ class JsonLogger(LoggerImpl):
61
69
  def get_json(self, line: LogLine) -> str:
62
70
  data: dict = {"state": asdict(line.state)}
63
71
 
64
- def add_logs(log: dict[str, dict[str, Any]], data: dict) -> None:
72
+ def add_logs(log: Mapping[str, Mapping[str, LogScalar | LogString]], data: dict) -> None:
65
73
  for namespace, values in log.items():
66
74
  if self.remove_unicode_from_namespaces:
67
75
  namespace = namespace.encode("ascii", errors="ignore").decode("ascii").strip()
68
76
  if namespace not in data:
69
77
  data[namespace] = {}
70
78
  for k, v in values.items():
71
- data[namespace][k] = get_json_value(v)
79
+ data[namespace][k] = get_json_value(v.value)
72
80
 
73
81
  add_logs(line.scalars, data)
74
82
  add_logs(line.strings, data)
@@ -4,11 +4,20 @@ import datetime
4
4
  import logging
5
5
  import sys
6
6
  from collections import deque
7
- from typing import Any, Deque, TextIO
7
+ from typing import Any, Deque, Mapping, TextIO
8
8
 
9
9
  from jaxtyping import Array
10
10
 
11
- from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
11
+ from xax.task.logger import (
12
+ LogError,
13
+ LogErrorSummary,
14
+ LoggerImpl,
15
+ LogLine,
16
+ LogPing,
17
+ LogScalar,
18
+ LogStatus,
19
+ LogString,
20
+ )
12
21
  from xax.utils.text import Color, colored, format_timedelta
13
22
 
14
23
 
@@ -95,20 +104,17 @@ class StdoutLogger(LoggerImpl):
95
104
  def write_log_window(self, line: LogLine) -> None:
96
105
  namespace_to_lines: dict[str, dict[str, str]] = {}
97
106
 
98
- def add_logs(log: dict[str, dict[str, Any]], namespace_to_lines: dict[str, dict[str, str]]) -> None:
107
+ def add_logs(
108
+ log: Mapping[str, Mapping[str, LogScalar | LogString]],
109
+ namespace_to_lines: dict[str, dict[str, str]],
110
+ ) -> None:
99
111
  for namespace, values in log.items():
100
- if not self.log_timers and namespace.startswith("⌛"):
101
- continue
102
- if not self.log_perf and namespace.startswith("🔧"):
103
- continue
104
- if not self.log_optim and namespace.startswith("📉"):
105
- continue
106
- if not self.log_fp and namespace.startswith("⚖️"):
107
- continue
108
- if namespace not in namespace_to_lines:
109
- namespace_to_lines[namespace] = {}
110
112
  for k, v in values.items():
111
- v_str = as_str(v, self.precision)
113
+ if v.secondary:
114
+ continue
115
+ if namespace not in namespace_to_lines:
116
+ namespace_to_lines[namespace] = {}
117
+ v_str = as_str(v.value, self.precision)
112
118
  namespace_to_lines[namespace][k] = v_str
113
119
 
114
120
  add_logs(line.scalars, namespace_to_lines)
@@ -116,9 +122,8 @@ class StdoutLogger(LoggerImpl):
116
122
  if not namespace_to_lines:
117
123
  return
118
124
 
119
- self.write_fp.write("\n")
120
125
  for namespace, lines in sorted(namespace_to_lines.items()):
121
- self.write_fp.write(f"{colored(namespace, 'cyan', bold=True)}\n")
126
+ self.write_fp.write(f"\n{colored(namespace, 'cyan', bold=True)}\n")
122
127
  for k, v in lines.items():
123
128
  self.write_fp.write(f" ↪ {k}: {v}\n")
124
129
 
@@ -158,7 +158,7 @@ class TensorboardLogger(LoggerImpl):
158
158
  for scalar_key, scalar_value in scalars.items():
159
159
  writer.add_scalar(
160
160
  f"{namespace}/{scalar_key}",
161
- as_float(scalar_value),
161
+ as_float(scalar_value.value),
162
162
  global_step=line.state.num_steps,
163
163
  walltime=walltime,
164
164
  )
@@ -192,7 +192,7 @@ class TensorboardLogger(LoggerImpl):
192
192
  for string_key, string_value in strings.items():
193
193
  writer.add_text(
194
194
  f"{namespace}/{string_key}",
195
- string_value,
195
+ string_value.value,
196
196
  global_step=line.state.num_steps,
197
197
  walltime=walltime,
198
198
  )
@@ -248,15 +248,15 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
248
248
  stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
249
249
 
250
250
  if stats is not None:
251
- self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
252
- self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
253
- self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
254
- self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
255
- self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
256
- self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
257
- self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
258
- self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
259
- self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
260
- self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
251
+ self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu", secondary=True)
252
+ self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu", secondary=True)
253
+ self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu", secondary=True)
254
+ self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem", secondary=True)
255
+ self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem", secondary=True)
256
+ self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem", secondary=True)
257
+ self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem", secondary=True)
258
+ self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem", secondary=True)
259
+ self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem", secondary=True)
260
+ self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem", secondary=True)
261
261
 
262
262
  return state
@@ -264,8 +264,8 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
264
264
  for gpu_stat in stats.values():
265
265
  if gpu_stat is None:
266
266
  continue
267
- self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
268
- self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
269
- self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
267
+ self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu", secondary=True)
268
+ self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu", secondary=True)
269
+ self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu", secondary=True)
270
270
 
271
271
  return state
xax/task/mixins/train.py CHANGED
@@ -50,8 +50,7 @@ from xax.utils.experiments import (
50
50
  TrainingFinishedError,
51
51
  diff_configs,
52
52
  get_diff_string,
53
- get_git_state,
54
- get_packages_with_versions,
53
+ get_state_file_string,
55
54
  get_training_code,
56
55
  )
57
56
  from xax.utils.jax import jit as xax_jit
@@ -219,7 +218,12 @@ class TrainMixin(
219
218
  return state.replace(elapsed_time_s=time.time() - state.start_time_s)
220
219
 
221
220
  def log_train_step(
222
- self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
221
+ self,
222
+ model: PyTree,
223
+ batch: Batch,
224
+ output: Output,
225
+ metrics: FrozenDict[str, Array],
226
+ state: State,
223
227
  ) -> None:
224
228
  """Override this function to do logging during the training phase.
225
229
 
@@ -235,7 +239,12 @@ class TrainMixin(
235
239
  """
236
240
 
237
241
  def log_valid_step(
238
- self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
242
+ self,
243
+ model: PyTree,
244
+ batch: Batch,
245
+ output: Output,
246
+ metrics: FrozenDict[str, Array],
247
+ state: State,
239
248
  ) -> None:
240
249
  """Override this function to do logging during the validation phase.
241
250
 
@@ -253,12 +262,20 @@ class TrainMixin(
253
262
  def log_state_timers(self, state: State) -> None:
254
263
  timer = self.state_timers[state.phase]
255
264
  timer.step(state)
256
- for ns, d in timer.log_dict().items():
257
- for k, v in d.items():
258
- self.logger.log_scalar(k, v, namespace=ns)
265
+ for k, v in timer.log_dict().items():
266
+ if isinstance(v, tuple):
267
+ v, secondary = v
268
+ else:
269
+ secondary = False
270
+ self.logger.log_scalar(k, v, namespace="⌛ timers", secondary=secondary)
259
271
 
260
272
  def log_step(
261
- self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
273
+ self,
274
+ model: PyTree,
275
+ batch: Batch,
276
+ output: Output,
277
+ metrics: FrozenDict[str, Array],
278
+ state: State,
262
279
  ) -> None:
263
280
  phase = state.phase
264
281
 
@@ -534,9 +551,8 @@ class TrainMixin(
534
551
  logger.log(LOG_STATUS, self.task_path)
535
552
  logger.log(LOG_STATUS, self.task_name)
536
553
  logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
537
- self.logger.log_file("git_state.txt", get_git_state(self))
538
- self.logger.log_file("packages.txt", get_packages_with_versions())
539
- self.logger.log_file("training_code.txt", get_training_code(self))
554
+ self.logger.log_file("state.txt", get_state_file_string(self))
555
+ self.logger.log_file("training_code.py", get_training_code(self))
540
556
  self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
541
557
 
542
558
  def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
xax/utils/experiments.py CHANGED
@@ -114,28 +114,15 @@ class StateTimer:
114
114
  self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
115
115
  self.iter_timer.step(cur_time)
116
116
 
117
- def log_dict(self) -> dict[str, dict[str, int | float]]:
118
- logs: dict[str, dict[str, int | float]] = {}
119
-
120
- # Logs step statistics.
121
- logs[" steps"] = {
122
- "total": self.step_timer.steps,
123
- "per-second": self.step_timer.steps_per_second,
124
- }
125
-
126
- # Logs sample statistics.
127
- logs["⌛ samples"] = {
128
- "total": self.sample_timer.steps,
129
- "per-second": self.sample_timer.steps_per_second,
130
- }
131
-
132
- # Logs full iteration statistics.
133
- logs["⌛ dt"] = {
134
- "iter": self.iter_timer.iter_seconds,
117
+ def log_dict(self) -> dict[str, int | float | tuple[int | float, bool]]:
118
+ return {
119
+ "steps": (self.step_timer.steps, True),
120
+ "steps/second": self.step_timer.steps_per_second,
121
+ "samples": (self.sample_timer.steps, True),
122
+ "samples/second": (self.sample_timer.steps_per_second, True),
123
+ "dt": self.iter_timer.iter_seconds,
135
124
  }
136
125
 
137
- return logs
138
-
139
126
 
140
127
  class IntervalTicker:
141
128
  def __init__(self, interval: float) -> None:
@@ -479,6 +466,20 @@ def get_packages_with_versions() -> str:
479
466
  return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
480
467
 
481
468
 
469
+ def get_command_line_string() -> str:
470
+ return " ".join(sys.argv)
471
+
472
+
473
+ def get_state_file_string(obj: object) -> str:
474
+ return "\n\n".join(
475
+ [
476
+ f"=== Command Line ===\n\n{get_command_line_string()}",
477
+ f"=== Git State ===\n\n{get_git_state(obj)}",
478
+ f"=== Packages ===\n\n{get_packages_with_versions()}",
479
+ ]
480
+ )
481
+
482
+
482
483
  def get_training_code(obj: object) -> str:
483
484
  """Gets the text from the file containing the provided object.
484
485
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.14
3
+ Version: 0.1.16
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=D7czvfKKQJlemPuatMPVYbAO4ST3U272QRIyTOru7JI,13989
1
+ xax/__init__.py,sha256=pSWV5RtPBJynHr7dCqscbnMkETZPUyw8D6MHK4CuS90,14104
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -10,14 +10,14 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
10
  xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
11
11
  xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
12
12
  xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
13
- xax/nn/geom.py,sha256=Bj9Z4Y-uoNQuaA_eB_MyG7yImZLuOq8KCLUj1l3daoc,4545
13
+ xax/nn/geom.py,sha256=PN0Ndn575aVtsSfxi67RghHB7luRkqtpS7bPbT1LpLE,5201
14
14
  xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
15
15
  xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
16
16
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
17
17
  xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
18
18
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
19
  xax/task/base.py,sha256=DqgGIlo5kEWpYix3DdPCEkCgVLUOocjyFr8okaSUq-k,7680
20
- xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
20
+ xax/task/logger.py,sha256=Upx7cCZvaVIs75CHTfIzYmsuaFRsGu0FvziTZuazT4k,37083
21
21
  xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
22
22
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
23
23
  xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -26,25 +26,25 @@ xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,140
26
26
  xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
27
27
  xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
29
- xax/task/loggers/json.py,sha256=yXHb1bmfsEnk4p0F1Up1ertWYdcPAFZm25NT8wE3Jb8,4045
29
+ xax/task/loggers/json.py,sha256=_tKum6jk_gqVzO-4MqSNXbE-Mmn-yJzkRAT-N1y2zes,4139
30
30
  xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
31
- xax/task/loggers/stdout.py,sha256=BBXqr95gNt5KuCN8XyKnTJF8JdwkR4JgLKrkvcaTBVM,6788
32
- xax/task/loggers/tensorboard.py,sha256=kI8LvBuBBhPgkP8TeaTQb9SQ0FqaIodwQh2SuWDCnIA,7706
31
+ xax/task/loggers/stdout.py,sha256=oeIgPkj4RyJgBuWaJK9ncLa65iBNJCWXhSF8fx3_54c,6564
32
+ xax/task/loggers/tensorboard.py,sha256=HjR-wiCWe0z3nivRzxEZIltzSzka1828bwxWVmMU5Sk,7718
33
33
  xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
34
34
  xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
35
35
  xax/task/mixins/checkpointing.py,sha256=nRddgtasagf0oTZE9LE5IN5JY7jy4BD_M0rlqYp4sCM,8554
36
36
  xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
37
- xax/task/mixins/cpu_stats.py,sha256=C_t71UTrv4LwQzhO5iubsfomj4jYa9bzpE4zBcHdoHM,9211
37
+ xax/task/mixins/cpu_stats.py,sha256=vAjEc3HpPnl56m7vshYX0dXAHJrB98DzVdsYSRqQllc,9371
38
38
  xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
39
- xax/task/mixins/gpu_stats.py,sha256=IGPBro9xzSivwD43zM18lWcuei7IhA8LilxSPHqNl4I,8747
39
+ xax/task/mixins/gpu_stats.py,sha256=4HU6teEDlqMitLbSx7fbyL4qBJ0PgGy0Ly_Pzife8yo,8795
40
40
  xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
41
41
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
42
42
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
43
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
44
- xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
44
+ xax/task/mixins/train.py,sha256=4Xr8b5LFueFh-f3k8MIJMv3M46_Aaf65YwCbjtSBQ-U,26393
45
45
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
46
  xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
47
- xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
47
+ xax/utils/experiments.py,sha256=vm_hWfaty_wEHVdoU2ALiBiGJze7IoDJIfXi6pd_a9I,29360
48
48
  xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
49
49
  xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
50
50
  xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.1.14.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.1.14.dist-info/METADATA,sha256=WbKtAXJUYKHvBrOJPEm_eXF9O9ekc0WdPmsQQCSGG5Q,1878
63
- xax-0.1.14.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
- xax-0.1.14.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.1.14.dist-info/RECORD,,
61
+ xax-0.1.16.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.1.16.dist-info/METADATA,sha256=gfh7iFi7Wz3fJDf2w1KKs8H0uanhn2HFsR67TvP6uZM,1878
63
+ xax-0.1.16.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
+ xax-0.1.16.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.1.16.dist-info/RECORD,,
File without changes