xax 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.2.5"
+__version__ = "0.2.7"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -66,11 +66,13 @@ __all__ = [
     "StateLogger",
     "StdoutLogger",
     "TensorboardLogger",
+    "load_ckpt",
     "CPUStatsOptions",
     "DataloaderConfig",
     "GPUStatsOptions",
     "StepContext",
     "ValidStepTimer",
+    "get_param_count",
     "Script",
     "ScriptConfig",
     "Config",
@@ -230,11 +232,13 @@ NAME_MAP: dict[str, str] = {
     "StateLogger": "task.loggers.state",
     "StdoutLogger": "task.loggers.stdout",
     "TensorboardLogger": "task.loggers.tensorboard",
+    "load_ckpt": "task.mixins.checkpointing",
     "CPUStatsOptions": "task.mixins.cpu_stats",
     "DataloaderConfig": "task.mixins.data_loader",
     "GPUStatsOptions": "task.mixins.gpu_stats",
     "StepContext": "task.mixins.step_wrapper",
     "ValidStepTimer": "task.mixins.train",
+    "get_param_count": "task.mixins.train",
     "Script": "task.script",
     "ScriptConfig": "task.script",
     "Config": "task.task",
@@ -390,11 +394,12 @@ if IMPORT_ALL or TYPE_CHECKING:
     from xax.task.loggers.state import StateLogger
     from xax.task.loggers.stdout import StdoutLogger
     from xax.task.loggers.tensorboard import TensorboardLogger
+    from xax.task.mixins.checkpointing import load_ckpt
     from xax.task.mixins.cpu_stats import CPUStatsOptions
     from xax.task.mixins.data_loader import DataloaderConfig
     from xax.task.mixins.gpu_stats import GPUStatsOptions
     from xax.task.mixins.step_wrapper import StepContext
-    from xax.task.mixins.train import Batch, Output, ValidStepTimer
+    from xax.task.mixins.train import Batch, Output, ValidStepTimer, get_param_count
     from xax.task.script import Script, ScriptConfig
     from xax.task.task import Config, Task
     from xax.utils.data.collate import CollateMode, collate, collate_non_null
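Both new symbols, load_ckpt and get_param_count, are wired into all three places the API machinery requires: __all__, the NAME_MAP used for lazy imports, and the eager-import block guarded by IMPORT_ALL or TYPE_CHECKING. For readers unfamiliar with the pattern, here is a minimal sketch of how a NAME_MAP-driven lazy import typically works; the actual __getattr__ in xax/__init__.py is outside this hunk, so the details below are an assumption about the pattern, not the verbatim implementation:

    import importlib
    import typing

    NAME_MAP: dict[str, str] = {"load_ckpt": "task.mixins.checkpointing"}

    def __getattr__(name: str) -> typing.Any:  # PEP 562 module-level __getattr__
        if name not in NAME_MAP:
            raise AttributeError(f"module 'xax' has no attribute {name!r}")
        # Import the owning submodule on first access, then pull the symbol from it.
        submodule = importlib.import_module(f"xax.{NAME_MAP[name]}")
        return getattr(submodule, name)

This keeps "import xax" cheap while still letting xax.load_ckpt resolve on demand.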
xax/nn/functions.py CHANGED
@@ -1,5 +1,5 @@
 # mypy: disable-error-code="override"
-"""Defines helper Torch functions."""
+"""Defines helper Jax functions."""
 
 import random
 from dataclasses import is_dataclass
xax/task/logger.py CHANGED
@@ -521,7 +521,8 @@ class LoggerImpl(ABC):
        Returns:
            If the logger should log the current step.
        """
-        return self.tickers[state.phase].tick(state.elapsed_time_s.item())
+        elapsed_time = state.elapsed_time_s.item() if state.phase == "train" else state.valid_elapsed_time_s.item()
+        return self.tickers[state.phase].tick(elapsed_time)
 
 
 class ToastHandler(logging.Handler):
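The fix routes validation-phase throttling through state.valid_elapsed_time_s instead of the training clock, so validation logs are no longer suppressed just because little training time passed between validation steps. The tickers mapping itself is not part of this diff; a minimal sketch of the per-phase interval-ticker contract the code appears to rely on (names illustrative, not the xax implementation):

    class IntervalTicker:
        # Fires at most once per `interval` seconds of the supplied clock.
        def __init__(self, interval: float) -> None:
            self.interval = interval
            self.last_tick: float | None = None

        def tick(self, elapsed_time: float) -> bool:
            if self.last_tick is None or elapsed_time - self.last_tick >= self.interval:
                self.last_tick = elapsed_time
                return True
            return False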
xax/task/loggers/json.py CHANGED
@@ -2,7 +2,6 @@
 
 import json
 import sys
-from dataclasses import asdict
 from typing import Any, Literal, Mapping, TextIO
 
 from jaxtyping import Array
@@ -67,7 +66,7 @@ class JsonLogger(LoggerImpl):
         return self.err_log_stream
 
     def get_json(self, line: LogLine) -> str:
-        data: dict = {"state": asdict(line.state)}
+        data: dict = {"state": line.state.to_dict()}
 
         def add_logs(log: Mapping[str, Mapping[str, LogScalar | LogString]], data: dict) -> None:
             for namespace, values in log.items():
xax/task/mixins/checkpointing.py CHANGED
@@ -52,6 +52,114 @@ class CheckpointingConfig(ArtifactsConfig):
 Config = TypeVar("Config", bound=CheckpointingConfig)
 
 
+@overload
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["all"],
+    model_template: PyTree,
+    optimizer_template: PyTree,
+    opt_state_template: PyTree,
+) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+
+
+@overload
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["model_state_config"],
+    model_template: PyTree,
+) -> tuple[PyTree, State, DictConfig]: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["model"], model_template: PyTree) -> PyTree: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["opt"], optimizer_template: PyTree) -> optax.GradientTransformation: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["opt_state"], opt_state_template: PyTree) -> optax.OptState: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["state"]) -> State: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["config"]) -> DictConfig: ...
+
+
+def load_ckpt(
+    path: str | Path,
+    *,
+    part: CheckpointPart = "model",
+    model_template: PyTree | None = None,
+    optimizer_template: PyTree | None = None,
+    opt_state_template: PyTree | None = None,
+) -> (
+    tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+    | tuple[PyTree, State, DictConfig]
+    | PyTree
+    | optax.GradientTransformation
+    | optax.OptState
+    | State
+    | DictConfig
+):
+    with tarfile.open(path, "r:gz") as tar:
+
+        def get_model() -> PyTree:
+            if model_template is None:
+                raise ValueError("model_template must be provided to load model weights")
+            if (model := tar.extractfile("model")) is None:
+                raise ValueError(f"Checkpoint does not contain a model file: {path}")
+            return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
+
+        def get_opt() -> optax.GradientTransformation:
+            if optimizer_template is None:
+                raise ValueError("optimizer_template must be provided to load optimizer")
+            if (opt := tar.extractfile("optimizer")) is None:
+                raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
+            return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
+
+        def get_opt_state() -> optax.OptState:
+            if opt_state_template is None:
+                raise ValueError("opt_state_template must be provided to load optimizer state")
+            if (opt_state := tar.extractfile("opt_state")) is None:
+                raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
+            return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
+
+        def get_state() -> State:
+            if (state := tar.extractfile("state")) is None:
+                raise ValueError(f"Checkpoint does not contain a state file: {path}")
+            return State.from_dict(**json.loads(state.read().decode()))
+
+        def get_config() -> DictConfig:
+            if (config := tar.extractfile("config")) is None:
+                raise ValueError(f"Checkpoint does not contain a config file: {path}")
+            return cast(DictConfig, OmegaConf.load(config))
+
+        match part:
+            case "model":
+                return get_model()
+            case "opt":
+                return get_opt()
+            case "opt_state":
+                return get_opt_state()
+            case "state":
+                return get_state()
+            case "config":
+                return get_config()
+            case "model_state_config":
+                return get_model(), get_state(), get_config()
+            case "all":
+                return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
+            case _:
+                raise ValueError(f"Invalid checkpoint part: {part}")
+
+
 class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def __init__(self, config: Config) -> None:
         super().__init__(config)
@@ -82,149 +190,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
             return True
         return False
 
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["all"],
-        model_template: PyTree,
-        optimizer_template: PyTree,
-        opt_state_template: PyTree,
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["model_state_config"],
-        model_template: PyTree,
-    ) -> tuple[PyTree, State, Config]: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["model"],
-        model_template: PyTree,
-    ) -> PyTree: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["opt"],
-        optimizer_template: PyTree,
-    ) -> optax.GradientTransformation: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["opt_state"],
-        opt_state_template: PyTree,
-    ) -> optax.OptState: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["state"],
-    ) -> State: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["config"],
-    ) -> Config: ...
-
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: CheckpointPart = "all",
-        model_template: PyTree | None = None,
-        optimizer_template: PyTree | None = None,
-        opt_state_template: PyTree | None = None,
-    ) -> (
-        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
-        | tuple[PyTree, State, Config]
-        | PyTree
-        | optax.GradientTransformation
-        | optax.OptState
-        | State
-        | Config
-    ):
-        """Load a checkpoint.
-
-        Args:
-            path: Path to the checkpoint directory
-            part: Which part of the checkpoint to load
-            model_template: Template model with correct structure but uninitialized weights
-            optimizer_template: Template optimizer with correct structure but uninitialized weights
-            opt_state_template: Template optimizer state with correct structure but uninitialized weights
-
-        Returns:
-            The requested checkpoint components
-        """
-        with tarfile.open(path, "r:gz") as tar:
-
-            def get_model() -> PyTree:
-                if model_template is None:
-                    raise ValueError("model_template must be provided to load model weights")
-                if (model := tar.extractfile("model")) is None:
-                    raise ValueError(f"Checkpoint does not contain a model file: {path}")
-                return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
-
-            def get_opt() -> optax.GradientTransformation:
-                if optimizer_template is None:
-                    raise ValueError("optimizer_template must be provided to load optimizer")
-                if (opt := tar.extractfile("optimizer")) is None:
-                    raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
-                return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
-
-            def get_opt_state() -> optax.OptState:
-                if opt_state_template is None:
-                    raise ValueError("opt_state_template must be provided to load optimizer state")
-                if (opt_state := tar.extractfile("opt_state")) is None:
-                    raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
-                return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
-
-            def get_state() -> State:
-                if (state := tar.extractfile("state")) is None:
-                    raise ValueError(f"Checkpoint does not contain a state file: {path}")
-                return State.from_dict(**json.loads(state.read().decode()))
-
-            def get_config() -> Config:
-                if (config := tar.extractfile("config")) is None:
-                    raise ValueError(f"Checkpoint does not contain a config file: {path}")
-                return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)
-
-            match part:
-                case "model":
-                    return get_model()
-                case "opt":
-                    return get_opt()
-                case "opt_state":
-                    return get_opt_state()
-                case "state":
-                    return get_state()
-                case "config":
-                    return get_config()
-                case "model_state_config":
-                    return get_model(), get_state(), get_config()
-                case "all":
-                    return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
-                case _:
-                    raise ValueError(f"Invalid checkpoint part: {part}")
-
     def save_checkpoint(
         self,
         model: PyTree | None = None,
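Net effect of this file's changes: the checkpoint reader moves from the CheckpointingMixin method load_ckpt_with_template to the standalone load_ckpt function, so checkpoints can be inspected without constructing a task. The one behavioral difference is that the standalone function returns the raw OmegaConf DictConfig for the config part, rather than a task-resolved config; the train mixin (next file) restores the old behaviour by re-wrapping it with self.get_config(..., use_cli=False). A hedged usage sketch, where the model constructor and checkpoint path are placeholders and not part of xax:

    import equinox as eqx
    import jax
    from xax.task.mixins.checkpointing import load_ckpt

    def make_model(key: jax.Array) -> eqx.Module:
        # Stand-in model; any Equinox module serialised into the checkpoint works.
        return eqx.nn.MLP(in_size=4, out_size=2, width_size=32, depth=2, key=key)

    # Shape-only template: gives tree_deserialise_leaves the tree structure
    # without materialising real weights (mirrors eqx.filter_eval_shape usage in the diff).
    template = eqx.filter_eval_shape(make_model, jax.random.PRNGKey(0))
    model = load_ckpt("run/checkpoints/ckpt.bin", part="model", model_template=template)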
xax/task/mixins/train.py CHANGED
@@ -40,7 +40,7 @@ from xax.core.state import Phase, State
 from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
-from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart, load_ckpt
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
@@ -96,6 +96,12 @@ def batches_per_step_schedule(schedule: list[int] | None) -> list[int] | None:
     return list(itertools.accumulate([0] + schedule))
 
 
+def get_param_count(pytree: PyTree) -> int:
+    """Calculates the total number of parameters in a PyTree."""
+    leaves, _ = jax.tree.flatten(pytree)
+    return sum(x.size for x in leaves if isinstance(x, jnp.ndarray))
+
+
 class ValidStepTimer:
     def __init__(
         self,
@@ -115,19 +121,22 @@ class ValidStepTimer:
         self.last_valid_time: float | None = None
         self.last_valid_step: int | None = None
 
+    def _reset(self, state: State) -> None:
+        self.last_valid_time = state.elapsed_time_s.item()
+        self.last_valid_step = state.num_steps.item()
+
     def is_valid_step(self, state: State) -> bool:
         if state.num_steps < self.valid_first_n_steps:
             return True
 
         if self.last_valid_time is None or self.last_valid_step is None:
-            self.last_valid_time = state.elapsed_time_s.item()
-            self.last_valid_step = state.num_steps.item()
+            self._reset(state)
             return False
 
         # Step-based validation.
         valid_every_n_steps = self.valid_every_n_steps
         if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
-            self.last_valid_step = state.num_steps.item()
+            self._reset(state)
             return True
 
         # Time-based validation.
@@ -136,14 +145,14 @@
             valid_every_n_seconds is not None
             and state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
         ):
-            self.last_valid_time = state.elapsed_time_s.item()
+            self._reset(state)
             return True
 
         # Time-based validation for first validation step.
         if self.first_valid_step_flag:
             valid_first_n_seconds = self.valid_first_n_seconds
             if valid_first_n_seconds is not None and state.elapsed_time_s.item() >= valid_first_n_seconds:
-                self.last_valid_time = state.elapsed_time_s.item()
+                self._reset(state)
                 self.first_valid_step_flag = False
                 return True
 
@@ -357,6 +366,7 @@ class TrainMixin(
         model = self.get_model(key)
         state = State.init_state()
 
+        self.log_model_size(model)
         if not load_optimizer:
             return model, state
 
@@ -447,44 +457,43 @@
         match part:
             case "model_state_config":
                 model_spec = eqx.filter_eval_shape(self.get_model, key)
-                return self.load_ckpt_with_template(path, part="model_state_config", model_template=model_spec)
+                model, state, config = load_ckpt(path, part="model_state_config", model_template=model_spec)
+                config = self.get_config(config, use_cli=False)
+                return model, state, config
 
             case "model":
                 model_spec = eqx.filter_eval_shape(self.get_model, key)
-                return self.load_ckpt_with_template(path, part="model", model_template=model_spec)
-
-            case "config":
-                return self.load_ckpt_with_template(path, part="config")
+                return load_ckpt(path, part="model", model_template=model_spec)
 
             case "opt":
                 optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                return self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                return load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
 
             case "opt_state":
                 if model is None:
                     model_spec = eqx.filter_eval_shape(self.get_model, key)
-                    model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                    model = load_ckpt(path, part="model", model_template=model_spec)
                 if optimizer is None:
                     optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                    optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                    optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
                 opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
-                return self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
+                return load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
 
             case "state":
-                return self.load_ckpt_with_template(path, part="state")
+                return load_ckpt(path, part="state")
 
             case "config":
-                return self.load_ckpt_with_template(path, part="config")
+                return self.get_config(load_ckpt(path, part="config"), use_cli=False)
 
             case "all":
                 model_spec = eqx.filter_eval_shape(self.get_model, key)
-                model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                model = load_ckpt(path, part="model", model_template=model_spec)
                 optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
                 opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
-                opt_state = self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
-                state = self.load_ckpt_with_template(path, part="state")
-                config = self.load_ckpt_with_template(path, part="config")
+                opt_state = load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
+                state = load_ckpt(path, part="state")
+                config = self.get_config(load_ckpt(path, part="config"), use_cli=False)
                 return model, optimizer, opt_state, state, config
 
             case _:
@@ -680,6 +689,9 @@
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
         self.logger.log_file("info.json", get_info_json())
 
+    def log_model_size(self, model: PyTree) -> None:
+        logger.info("Model size: %s", f"{get_param_count(model):,}")
+
     def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
         return eqx.is_inexact_array(item)
 
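The new get_param_count counts only jnp.ndarray leaves, so static fields and Python scalars in a module are ignored; log_model_size prints the result with thousands separators when the model is first built. A small sanity check (the Linear layer is an arbitrary example, not from xax):

    import equinox as eqx
    import jax
    from xax.task.mixins.train import get_param_count

    model = eqx.nn.Linear(in_features=3, out_features=5, key=jax.random.PRNGKey(0))
    # weight (5, 3) + bias (5,) -> 20 parameters
    print(f"{get_param_count(model):,}")  # 20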
xax/utils/jaxpr.py CHANGED
@@ -3,10 +3,10 @@
 from pathlib import Path
 
 import jax
-import jax.core
+import jax.extend.core
 
 
-def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) -> None:
+def save_jaxpr_dot(closed_jaxpr: jax.extend.core.ClosedJaxpr, filename: str | Path) -> None:
     """Save the JAXPR to a DOT file.
 
     Example usage:
@@ -30,15 +30,15 @@ def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) ->
     with open(filename, "w") as f:
         f.write("digraph Jaxpr {\n")
 
-        var_names: dict[jax.core.Var, str] = {}
+        var_names: dict[jax.extend.core.Var, str] = {}
         var_count = 0
 
-        def get_var_name(var: jax.core.Var) -> str:
+        def get_var_name(var: jax.extend.core.Var) -> str:
             """Get a unique name for a variable."""
             nonlocal var_names, var_count
 
             # Handle Literal objects specially since they're not hashable
-            if isinstance(var, jax.core.Literal):
+            if isinstance(var, jax.extend.core.Literal):
                 # Create a name based on the literal value
                 name = f"lit_{var.val}"
                 return name
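jax.core has been deprecated for external use, and ClosedJaxpr, Var, and Literal now live under jax.extend.core; the call pattern is unchanged. A usage sketch (the traced function and output filename are illustrative):

    import jax
    import jax.numpy as jnp
    from xax.utils.jaxpr import save_jaxpr_dot

    def fn(x: jax.Array) -> jax.Array:
        return jnp.sin(x) * 2.0

    # jax.make_jaxpr traces fn and returns a ClosedJaxpr.
    closed_jaxpr = jax.make_jaxpr(fn)(jnp.ones(3))
    save_jaxpr_dot(closed_jaxpr, "fn.dot")  # render with: dot -Tpng fn.dot -o fn.png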
xax/utils/pytree.py CHANGED
@@ -57,7 +57,7 @@ def pytree_has_nans(pytree: PyTree) -> Array:
 
 def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
     """Update a pytree based on a condition."""
-    # Tricky, need use tree_map because where expects array leafs.
+    # Tricky, need use tree.map because where expects array leafs.
     return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)
 
 
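The comment now matches the jax.tree.map call the function actually makes. For reference, update_pytree picks each leaf from new where cond is true and from original otherwise, which is useful, e.g., together with pytree_has_nans (named in the hunk header) to keep the previous tree when an update contains NaNs. A quick sketch of the semantics:

    import jax
    import jax.numpy as jnp

    new = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
    original = {"w": jnp.array([9.0, 9.0]), "b": jnp.array(0.0)}
    cond = jnp.array(True)

    # Same body as update_pytree: jnp.where applied leaf-by-leaf.
    updated = jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)
    # -> {"b": 3.0, "w": [1.0, 2.0]}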
xax/utils/types/frozen_dict.py CHANGED
@@ -138,7 +138,7 @@ class FrozenDict(Mapping[K, V]):
 
 def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]:  # noqa: ANN401
     if isinstance(x, FrozenDict):
-        return jax.tree_util.tree_map(lambda y: y, x._dict)
+        return jax.tree.map(lambda y: y, x._dict)
     elif isinstance(x, dict):
         ys = {}
         for key, value in x.items():
xax-0.2.5.dist-info/METADATA → xax-0.2.7.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.2.5
+Version: 0.2.7
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
xax-0.2.5.dist-info/RECORD → xax-0.2.7.dist-info/RECORD RENAMED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=X_QqDNJir1wdsfRY1CU1F4mdCQMlMZnyqPtY8MM1ODU,14225
+xax/__init__.py,sha256=94V0RHzCNC-aVXtpv5jtYaXJWytSJYOJCjUR69dIM1g,14428
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -9,7 +9,7 @@ xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
 xax/nn/export.py,sha256=pRfM2B4hB2EvljysC6AjtgB_7Cn7JtaP3dhYU2stZtY,5545
-xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
+xax/nn/functions.py,sha256=Gig46ZjFxzXl3ImOFvpAO6oG6xrnAZywp1Ez8Gwy0BU,2711
 xax/nn/geom.py,sha256=rImNlkHWeoNcY7f84nknizJ6uzsrMhbAtKeb2xAWxNY,6215
 xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
 xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
@@ -17,7 +17,7 @@ xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
 xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=OnXi2hiKPGwt6ng1dutnoQSiw7lEiWFlC_vx99_JsbQ,7694
-xax/task/logger.py,sha256=y4PGfMqKbfvPk8WCzr9MOsgG2X9E61KgeBVOYp-9kOY,40875
+xax/task/logger.py,sha256=gE67AaPCfU_1FpxY3t0yNRrIVgtp8Sax9UyOqFYMtzM,40976
 xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
 xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
 xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -26,13 +26,13 @@ xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,140
 xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
 xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
-xax/task/loggers/json.py,sha256=_tKum6jk_gqVzO-4MqSNXbE-Mmn-yJzkRAT-N1y2zes,4139
+xax/task/loggers/json.py,sha256=Ukbo6eAq9mSDmo7AqCVu9OFFjjSjcIsdeKP_1WQbyuw,4110
 xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
 xax/task/loggers/stdout.py,sha256=ERLFrYe61hSSztzyxBRseobHQR72YFYjEd2i_hOeJ20,6595
 xax/task/loggers/tensorboard.py,sha256=KFlsK0zD2ubDqAXYL4Ds7NQ9F-Ke-PHwfhLOYsGcbw4,8306
 xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
 xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
-xax/task/mixins/checkpointing.py,sha256=8Hi-2G0EA5OFRjgiOutlk7HgkD5b-0GHazOAYxnGytM,11409
+xax/task/mixins/checkpointing.py,sha256=zqospBFnTbGt_iriiduVfXazINPbzWpwmIs91KAniMY,10147
 xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
 xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
 xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
@@ -41,25 +41,25 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=XcetJ0MppV_RDhgg1M9_d9heEXo-zeN_FS3MyczeBBU,31219
+xax/task/mixins/train.py,sha256=t2KW18S9vpUuvmr3VAeKzSEcEdAD6fbi_fjk-KZ6ssA,31426
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
 xax/utils/experiments.py,sha256=d2H63ECtVOKySMUMrQRqq4kcuZpoXqo-L931usDVAhE,29903
 xax/utils/jax.py,sha256=KQYUHjN6t6JIWa11aRSO3edcsAgTscw_dExxI6kCd9g,6767
-xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
+xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
 xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
 xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
-xax/utils/pytree.py,sha256=VFWhT0MQ99KjQyEYM6NFbqYq4_hOZwB23uhowMB4U34,8754
+xax/utils/pytree.py,sha256=Q_5NCr90Wlqw1x4yK-FXfDVZEkNK2figmsUSZ4fnxYk,8754
 xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
 xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
 xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
+xax/utils/types/frozen_dict.py,sha256=s57XaTo2jgeT4_SzNQEjsYb4XrNZJwm1ca4ZdXSE5TY,4676
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.2.5.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
-xax-0.2.5.dist-info/METADATA,sha256=4RBxZF_P0cg-a6QUNS9urvzc4BGGfoedqMrnP0L6Ksk,1879
-xax-0.2.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-xax-0.2.5.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
-xax-0.2.5.dist-info/RECORD,,
+xax-0.2.7.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.2.7.dist-info/METADATA,sha256=7otxcR5N4nVGyao5-NsrLfyuY1LqX8v2PWMEWB3sVy8,1879
+xax-0.2.7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.2.7.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.2.7.dist-info/RECORD,,