xax 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry, and is provided for informational purposes only.
- xax/__init__.py +2 -15
- xax/core/state.py +22 -17
- xax/task/base.py +11 -5
- xax/task/mixins/train.py +114 -52
- xax/task/script.py +3 -0
- xax/utils/pytree.py +11 -11
- {xax-0.1.8.dist-info → xax-0.1.10.dist-info}/METADATA +1 -1
- {xax-0.1.8.dist-info → xax-0.1.10.dist-info}/RECORD +11 -11
- {xax-0.1.8.dist-info → xax-0.1.10.dist-info}/WHEEL +0 -0
- {xax-0.1.8.dist-info → xax-0.1.10.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.8.dist-info → xax-0.1.10.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED

@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.1.8"
+__version__ = "0.1.10"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -23,7 +23,6 @@ __all__ = [
     "get_run_dir",
     "load_user_config",
    "State",
-    "cast_phase",
     "FourierEmbeddings",
     "IdentityPositionalEmbeddings",
     "LearnedPositionalEmbeddings",
@@ -41,9 +40,6 @@ __all__ = [
     "load_eqx_mlp",
     "make_eqx_mlp",
     "save_eqx",
-    "export",
-    "export_flax",
-    "export_with_params",
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
@@ -180,7 +176,6 @@ NAME_MAP: dict[str, str] = {
     "get_run_dir": "core.conf",
     "load_user_config": "core.conf",
     "State": "core.state",
-    "cast_phase": "core.state",
     "FourierEmbeddings": "nn.embeddings",
     "IdentityPositionalEmbeddings": "nn.embeddings",
     "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -198,9 +193,6 @@ NAME_MAP: dict[str, str] = {
     "load_eqx_mlp": "nn.equinox",
     "make_eqx_mlp": "nn.equinox",
     "save_eqx": "nn.equinox",
-    "export": "nn.export",
-    "export_flax": "nn.export",
-    "export_with_params": "nn.export",
     "euler_to_quat": "nn.geom",
     "get_projected_gravity_vector_from_quat": "nn.geom",
     "quat_to_euler": "nn.geom",
@@ -329,7 +321,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         get_run_dir,
         load_user_config,
     )
-    from xax.core.state import Phase, State
+    from xax.core.state import Phase, State
     from xax.nn.embeddings import (
         EmbeddingKind,
         FourierEmbeddings,
@@ -354,11 +346,6 @@ if IMPORT_ALL or TYPE_CHECKING:
         make_eqx_mlp,
         save_eqx,
     )
-    from xax.nn.export import (
-        export,
-        export_flax,
-        export_with_params,
-    )
     from xax.nn.geom import (
         euler_to_quat,
         get_projected_gravity_vector_from_quat,
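The upshot for downstream code: `cast_phase` is deleted outright (see xax/core/state.py below), and `export`, `export_flax`, and `export_with_params` are no longer re-exported from the top-level `xax` namespace. The RECORD diff at the bottom only touches changed files, which suggests `xax.nn.export` itself still ships in the wheel; under that assumption, a migration sketch:

    # 0.1.8 style, no longer available:
    # from xax import export, export_flax, export_with_params

    # 0.1.10 style (assumed, based on the removed NAME_MAP entries above):
    from xax.nn.export import export, export_flax, export_with_params

There is no direct replacement for `cast_phase`; the new `State.replace(phase=...)` in the next file performs the equivalent validation internally.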
xax/core/state.py
CHANGED

@@ -1,9 +1,10 @@
 """Defines a dataclass for keeping track of the current training state."""
 
 import time
-from dataclasses import dataclass
-from typing import Literal, NotRequired, TypedDict, cast, get_args
+from dataclasses import asdict, dataclass
+from typing import Any, Literal, NotRequired, TypedDict, Unpack, cast
 
+import jax
 from omegaconf import MISSING
 
 from xax.core.conf import field
@@ -11,12 +12,6 @@ from xax.core.conf import field
 Phase = Literal["train", "valid"]
 
 
-def cast_phase(raw_phase: str) -> Phase:
-    args = get_args(Phase)
-    assert raw_phase in args, f"Invalid phase: '{raw_phase}' Valid options are {args}"
-    return cast(Phase, raw_phase)
-
-
 class StateDict(TypedDict, total=False):
     num_steps: NotRequired[int]
     num_samples: NotRequired[int]
@@ -24,10 +19,11 @@ class StateDict(TypedDict, total=False):
     num_valid_samples: NotRequired[int]
     start_time_s: NotRequired[float]
     elapsed_time_s: NotRequired[float]
-    raw_phase: NotRequired[str]
+    phase: NotRequired[Phase]
 
 
-@dataclass(kw_only=True)
+@jax.tree_util.register_dataclass
+@dataclass(frozen=True, kw_only=True)
 class State:
     num_steps: int = field(MISSING, help="Number of steps so far")
     num_samples: int = field(MISSING, help="Number of sample so far")
@@ -35,15 +31,11 @@
     num_valid_samples: int = field(MISSING, help="Number of validation samples so far")
     start_time_s: float = field(MISSING, help="Start time of training")
     elapsed_time_s: float = field(MISSING, help="Total elapsed time so far")
-    raw_phase: str = field(MISSING, help="Current training phase")
+    _phase: int = field(MISSING, help="Current training phase")
 
     @property
     def phase(self) -> Phase:
-        return cast_phase(self.raw_phase)
-
-    @phase.setter
-    def phase(self, phase: Phase) -> None:
-        self.raw_phase = phase
+        return cast(Phase, ["train", "valid"][self._phase])
 
     @classmethod
     def init_state(cls) -> "State":
@@ -54,7 +46,7 @@
             num_valid_samples=0,
             start_time_s=time.time(),
             elapsed_time_s=0.0,
-            raw_phase="train",
+            _phase=0,
         )
 
     @property
@@ -69,3 +61,16 @@
                 return self.num_valid_steps
             case _:
                 raise ValueError(f"Invalid phase: {phase}")
+
+    def replace(self, **kwargs: Unpack[StateDict]) -> "State":
+        extra_kwargs: dict[str, Any] = {}  # noqa: ANN401
+        if "phase" in kwargs:
+            phase = kwargs.pop("phase")
+            match phase:
+                case "train":
+                    extra_kwargs["_phase"] = 0
+                case "valid":
+                    extra_kwargs["_phase"] = 1
+                case _:
+                    raise ValueError(f"Invalid phase: {phase}")
+        return State(**{**asdict(self), **kwargs, **extra_kwargs})
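Taken together, `State` is now a frozen dataclass registered as a JAX pytree: every field is a numeric leaf, so a `State` can be passed through jit-compiled functions, and all updates go through the new `replace` method, which maps the public `phase` string onto the integer `_phase` leaf. A minimal usage sketch built only from the fields shown above:

    from xax.core.state import State

    state = State.init_state()
    assert state.phase == "train"

    # The 0.1.8 mutating style would now raise FrozenInstanceError:
    #   state.phase = "valid"
    # Instead, replace() returns a fresh State and validates the phase string:
    state = state.replace(phase="valid", num_valid_steps=state.num_valid_steps + 1)
    assert state.phase == "valid"

Storing the phase as an int rather than a string is what keeps every leaf traceable; this is the change that lets the train mixin below thread `state` through jit-compiled steps.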
xax/task/base.py
CHANGED

@@ -16,7 +16,8 @@ from types import TracebackType
 from typing import Generic, Self, TypeVar, cast
 
 import jax
-from omegaconf import Container, DictConfig, OmegaConf
+from omegaconf import DictConfig, OmegaConf
+from omegaconf.base import SCMode
 
 from xax.core.state import State
 from xax.utils.text import camelcase_to_snakecase
@@ -66,9 +67,6 @@ class BaseTask(Generic[Config]):
 
         self.config = config
 
-        if isinstance(self.config, Container):
-            OmegaConf.resolve(self.config)
-
     def on_step_start(self, state: State) -> State:
         return state
 
@@ -195,7 +193,15 @@
         cfg = OmegaConf.merge(cfg, *(get_config(path, task_path) for path in paths))
         cfg = OmegaConf.merge(cfg, OmegaConf.from_cli(non_paths))
 
-        return cast(Config, cfg)
+        return cast(
+            Config,
+            OmegaConf.to_container(
+                cfg,
+                resolve=True,
+                throw_on_missing=True,
+                structured_config_mode=SCMode.INSTANTIATE,
+            ),
+        )
 
     @classmethod
     def config_str(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> str:
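For context, `SCMode.INSTANTIATE` makes `OmegaConf.to_container` return actual dataclass instances instead of `DictConfig` wrappers, with interpolations resolved and `MISSING` fields raising immediately; that is also why the `OmegaConf.resolve` call in the constructor above could be dropped. A standalone sketch of the same call pattern, using a made-up config class rather than xax's:

    from dataclasses import dataclass

    from omegaconf import OmegaConf
    from omegaconf.base import SCMode


    @dataclass
    class DemoConfig:  # hypothetical, for illustration only
        batch_size: int = 32
        run_name: str = "bs${batch_size}"  # interpolation, resolved below


    cfg = OmegaConf.merge(OmegaConf.structured(DemoConfig), OmegaConf.from_dotlist(["batch_size=64"]))
    demo = OmegaConf.to_container(
        cfg,
        resolve=True,  # resolve ${...} interpolations now
        throw_on_missing=True,  # fail fast on MISSING fields
        structured_config_mode=SCMode.INSTANTIATE,  # return a DemoConfig, not a dict
    )
    assert isinstance(demo, DemoConfig) and demo.run_name == "bs64"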
xax/task/mixins/train.py
CHANGED

@@ -53,8 +53,10 @@ from xax.utils.experiments import (
     get_packages_with_versions,
     get_training_code,
 )
+from xax.utils.jax import jit as xax_jit
 from xax.utils.logging import LOG_STATUS
 from xax.utils.text import highlight_exception_message, show_info
+from xax.utils.types.frozen_dict import FrozenDict
 
 logger = logging.getLogger(__name__)
 
@@ -212,32 +214,31 @@ class TrainMixin(
 
     def on_step_end(self, state: State) -> State:
         state = super().on_step_end(state)
-        state.elapsed_time_s = time.time() - state.start_time_s
-        return state
+        return state.replace(elapsed_time_s=time.time() - state.start_time_s)
 
-    def log_train_step(self,
+    def log_train_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
         """Override this function to do logging during the training phase.
 
         This function is called after the model forward pass and before the
         backward pass. It is called in the training phase.
 
         Args:
-            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
+            metrics: The metrics for the current batch.
             state: The current training state.
         """
 
-    def log_valid_step(self,
+    def log_valid_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
         """Override this function to do logging during the validation phase.
 
         This function is called after the model forward pass. It is called in
         the validation phase.
 
         Args:
-            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
+            metrics: The metrics for the current batch.
             state: The current training state.
         """
 
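Subclasses that override these hooks must adapt: the `model` (and, for the training hook, `loss`) parameters are gone, and per-batch metrics arrive pre-computed as a `FrozenDict`. A sketch of the new override shape, with an invented task and logging key:

    from xax.task.mixins.train import TrainMixin  # module path per this diff


    class MyTask(TrainMixin):  # hypothetical; generics and config plumbing omitted
        def log_valid_step(self, batch, output, metrics, state) -> None:
            # metrics already contains "loss" by default (see compute_metrics
            # below); log one extra derived scalar alongside it.
            self.logger.log_scalar("valid_output_mean", output.mean().item())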
@@ -248,18 +249,23 @@ class TrainMixin(
         for k, v in d.items():
             self.logger.log_scalar(k, v, namespace=ns)
 
-    def log_step(self,
+    def log_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
         phase = state.phase
 
-
+        for k, v in metrics.items():
+            if v.size == 1:
+                self.logger.log_scalar(k, v.item())
+            else:
+                self.logger.log_histogram(k, v)
+
         self.log_state_timers(state)
 
         # Delegate to the appropriate logging function based on the phase.
         match phase:
             case "train":
-                self.log_train_step(
+                self.log_train_step(batch, output, metrics, state)
             case "valid":
-                self.log_valid_step(
+                self.log_valid_step(batch, output, metrics, state)
             case _:
                 raise KeyError(f"Unknown phase: {phase}")
 
|
|
332
338
|
|
333
339
|
return model, optimizer, opt_state, state
|
334
340
|
|
335
|
-
|
336
|
-
def get_output(self, model: PyTree, batch: Batch) -> Output:
|
341
|
+
def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
|
337
342
|
"""Gets the output from the model.
|
338
343
|
|
339
344
|
By default, we assume the model is a function that takes the batch as
|
@@ -343,11 +348,11 @@ class TrainMixin(
|
|
343
348
|
Args:
|
344
349
|
model: The current model.
|
345
350
|
batch: The current minibatch of samples.
|
351
|
+
state: The current training state.
|
346
352
|
"""
|
347
353
|
raise NotImplementedError("`get_output` must be implemented by the subclass")
|
348
354
|
|
349
|
-
|
350
|
-
def compute_loss(self, model: PyTree, batch: Batch, output: Output) -> Array:
|
355
|
+
def compute_loss(self, model: PyTree, batch: Batch, output: Output, state: State) -> Array:
|
351
356
|
"""Gets the loss for the current batch.
|
352
357
|
|
353
358
|
By default, we assume the model is a function that takes the batch as
|
@@ -358,6 +363,7 @@ class TrainMixin(
|
|
358
363
|
model: The current model.
|
359
364
|
batch: The current minibatch of samples.
|
360
365
|
output: The output from the model.
|
366
|
+
state: The current training state.
|
361
367
|
|
362
368
|
Returns:
|
363
369
|
The computed loss, as a tensor.
|
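Both hooks now receive the training `state`, so phase- or step-dependent behavior (dropout toggles, loss schedules) no longer needs side channels. A sketch of overrides under the new signatures, for a toy supervised setup:

    import jax.numpy as jnp


    class SupervisedTask(TrainMixin):  # hypothetical, as in the sketch above
        def get_output(self, model, batch, state):
            x, _ = batch
            # state.phase ("train" or "valid") could switch behavior here.
            return model(x)

        def compute_loss(self, model, batch, output, state):
            _, y = batch
            return jnp.mean((output - y) ** 2)  # plain MSE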
@@ -366,24 +372,59 @@
             raise ValueError(f"When model output is not the loss, you must override `compute_loss`. Got {type(output)}")
         return output
 
-
-
-
-
-
+    def compute_metrics(
+        self,
+        model: PyTree,
+        batch: Batch,
+        output: Output,
+        loss: Array,
+        state: State,
+    ) -> dict[str, Array]:
+        """Computes the metrics for the current batch.
+
+        Args:
+            model: The current model.
+            batch: The current minibatch of samples.
+            output: The output from the model.
+            loss: The loss for the current batch.
+            state: The current training state.
+
+        Returns:
+            A dictionary of metrics.
+        """
+        return {
+            "loss": loss,
+        }
+
+    @xax_jit(static_argnames=["self", "model_static"])
+    def get_output_and_loss(
+        self,
+        model_arr: PyTree,
+        model_static: PyTree,
+        batch: Batch,
+        state: State,
+    ) -> tuple[Array, tuple[Output, FrozenDict[str, Array]]]:
+        model = eqx.combine(model_arr, model_static)
+        output = self.get_output(model, batch, state)
+        loss = self.compute_loss(model, batch, output, state)
+        metrics = self.compute_metrics(model, batch, output, loss, state)
+        return loss, (output, FrozenDict(metrics))
 
-    @eqx.filter_jit
     def update(
         self,
-
+        model_arr: PyTree,
+        model_static: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
-
-
-
-
-
+        state: State,
+    ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+        grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
+        grad_fn = xax_jit(static_argnums=[1])(grad_fn)
+        grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
+        updates, opt_state = optimizer.update(grads, opt_state, model_arr)
+        model_arr = eqx.apply_updates(model_arr, updates)
+        return model_arr, opt_state, output, metrics
 
     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
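The rewritten update path is the standard Equinox split: partition the model into an array pytree (differentiable leaves) and a static remainder, take gradients with respect to the arrays only, and recombine inside the compiled function. A self-contained sketch of the same pattern in plain JAX/Equinox/optax, independent of xax:

    import equinox as eqx
    import jax
    import jax.numpy as jnp
    import optax

    model = eqx.nn.MLP(in_size=3, out_size=1, width_size=8, depth=1, key=jax.random.PRNGKey(0))
    optimizer = optax.adam(1e-3)

    # Arrays become differentiable leaves; activations, sizes, etc. stay static.
    model_arr, model_static = eqx.partition(model, eqx.is_inexact_array)
    opt_state = optimizer.init(model_arr)

    def loss_fn(model_arr, model_static, x, y):
        model = eqx.combine(model_arr, model_static)  # rebuild the full module
        pred = jax.vmap(model)(x)
        loss = jnp.mean((pred - y) ** 2)
        return loss, {"loss": loss}  # aux output, like compute_metrics above

    @jax.jit
    def step(model_arr, opt_state, x, y):
        grad_fn = jax.grad(loss_fn, argnums=0, has_aux=True)
        grads, metrics = grad_fn(model_arr, model_static, x, y)
        updates, opt_state = optimizer.update(grads, opt_state, model_arr)
        return eqx.apply_updates(model_arr, updates), opt_state, metrics

    x, y = jnp.ones((16, 3)), jnp.zeros((16, 1))
    model_arr, opt_state, metrics = step(model_arr, opt_state, x, y)

`xax_jit` (imported above from `xax.utils.jax`) plays the role of `jax.jit` here, with `model_static` passed as a static argument rather than closed over.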
@@ -457,21 +498,32 @@
         self.logger.log_file("training_code.txt", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
 
-
+    def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
+        return eqx.is_inexact_array(item)
+
+    @xax_jit(static_argnames=["self", "model_static", "optimizer"])
     def train_step(
         self,
-
+        model_arr: PyTree,
+        model_static: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
         batch: Batch,
-
-
-
+        state: State,
+    ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
+        model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
+        return model_arr, opt_state, output, metrics
 
-    @eqx.filter_jit
-    def val_step(
-
-
+    @xax_jit(static_argnames=["self", "model_static"])
+    def val_step(
+        self,
+        model_arr: PyTree,
+        model_static: PyTree,
+        batch: Batch,
+        state: State,
+    ) -> tuple[Output, FrozenDict[str, Array]]:
+        _, (output, metrics) = self.get_output_and_loss(model_arr, model_static, batch, state)
+        return output, metrics
 
     def train_loop(
         self,
@@ -482,36 +534,46 @@
         valid_pf: Iterator[Batch],
         state: State,
     ) -> None:
+        model_arr, model_static = eqx.partition(model, self.model_partition_fn)
+
         while not self.is_training_over(state):
             if self.valid_step_timer.is_valid_step(state):
                 valid_batch = next(valid_pf)
-
-
+                state = state.replace(
+                    phase="valid",
+                    num_valid_steps=state.num_valid_steps + 1,
+                    num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
+                )
 
-
-
-                state.phase = "valid"
-                self.log_step(model, valid_batch, output, loss, state)
-                state.num_valid_samples += 1
+                output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
+                self.log_step(valid_batch, output, metrics, state)
 
             state = self.on_step_start(state)
-
-
-
-
-
-
-
-
-
-
+            train_batch = next(train_pf)
+            state = state.replace(
+                phase="train",
+                num_steps=state.num_steps + 1,
+                num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
+            )
+
+            model_arr, opt_state, output, metrics = self.train_step(
+                model_arr=model_arr,
+                model_static=model_static,
+                optimizer=optimizer,
+                opt_state=opt_state,
+                batch=train_batch,
+                state=state,
+            )
+            self.log_step(train_batch, output, metrics, state)
 
             state = self.on_step_end(state)
 
             if self.should_checkpoint(state):
+                model = eqx.combine(model_arr, model_static)
                 self.save_checkpoint(model, optimizer, opt_state, state)
 
         # After finishing training, save the final checkpoint.
+        model = eqx.combine(model_arr, model_static)
         self.save_checkpoint(model, optimizer, opt_state, state)
 
     @contextlib.contextmanager
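Note the shape of the loop: `eqx.partition` runs once before the loop, the jitted `train_step`/`val_step` only ever see `model_arr` (with `model_static` marked static, avoiding retraces), and `eqx.combine` is deferred to checkpoint time. Continuing the standalone sketch above, with a stubbed checkpoint function that is not xax's API:

    def save_checkpoint(model, step):  # stand-in for the task's checkpointing
        print(f"saving at step {step}")

    model_arr, model_static = eqx.partition(model, eqx.is_inexact_array)
    for i in range(1000):
        model_arr, opt_state, metrics = step(model_arr, opt_state, x, y)
        if i % 100 == 99:
            # Rebuild the full module only when it is actually needed.
            save_checkpoint(eqx.combine(model_arr, model_static), i)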
xax/task/script.py
CHANGED

@@ -3,6 +3,8 @@
 from dataclasses import dataclass
 from typing import Generic, TypeVar
 
+import jax
+
 from xax.task.base import BaseConfig, BaseTask
 from xax.task.mixins import (
     ArtifactsConfig,
@@ -20,6 +22,7 @@ from xax.task.mixins import (
 )
 
 
+@jax.tree_util.register_dataclass
 @dataclass(kw_only=True)
 class ScriptConfig(
     CPUStatsConfig,
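Registering the config dataclass as a pytree lets instances participate in JAX transformations and pytree utilities instead of being opaque Python objects. A minimal demo of the same `jax.tree_util.register_dataclass` decorator on a toy class (the bare, no-argument form infers the data fields from the dataclass fields and requires a reasonably recent JAX):

    from dataclasses import dataclass

    import jax
    import jax.numpy as jnp


    @jax.tree_util.register_dataclass
    @dataclass(frozen=True)
    class Affine:  # toy example, unrelated to ScriptConfig
        scale: jax.Array
        bias: jax.Array


    @jax.jit
    def apply(p: Affine, x: jax.Array) -> jax.Array:
        return p.scale * x + p.bias


    p = Affine(scale=jnp.float32(2.0), bias=jnp.float32(1.0))
    print(apply(p, jnp.arange(3.0)))  # [1. 3. 5.]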
xax/utils/pytree.py
CHANGED

@@ -31,7 +31,7 @@ def slice_array(x: Array, start: Array, slice_length: int) -> Array:
 
 def slice_pytree(pytree: PyTree, start: Array, slice_length: int) -> PyTree:
     """Get a slice of a pytree."""
-    return jax.
+    return jax.tree.map(lambda x: slice_array(x, start, slice_length), pytree)
 
 
 def flatten_array(x: Array, flatten_size: int) -> Array:
@@ -43,14 +43,14 @@ def flatten_array(x: Array, flatten_size: int) -> Array:
 
 def flatten_pytree(pytree: PyTree, flatten_size: int) -> PyTree:
     """Flatten a pytree into a (flatten_size, ...) pytree."""
-    return jax.
+    return jax.tree.map(lambda x: flatten_array(x, flatten_size), pytree)
 
 
 def pytree_has_nans(pytree: PyTree) -> Array:
     """Check if a pytree has any NaNs."""
     has_nans = jax.tree_util.tree_reduce(
         lambda a, b: jnp.logical_or(a, b),
-        jax.
+        jax.tree.map(lambda x: jnp.any(jnp.isnan(x)), pytree),
     )
     return has_nans
 
@@ -58,13 +58,13 @@ def pytree_has_nans(pytree: PyTree) -> Array:
 def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
     """Update a pytree based on a condition."""
     # Tricky, need use tree_map because where expects array leafs.
-    return jax.
+    return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)
 
 
 def compute_nan_ratio(pytree: PyTree) -> Array:
     """Computes the ratio of NaNs vs non-NaNs in a given PyTree."""
-    nan_counts = jax.
-    total_counts = jax.
+    nan_counts = jax.tree.map(lambda x: jnp.sum(jnp.isnan(x)), pytree)
+    total_counts = jax.tree.map(lambda x: x.size, pytree)
 
     total_nans = jax.tree_util.tree_reduce(lambda a, b: a + b, nan_counts, 0)
     total_elements = jax.tree_util.tree_reduce(lambda a, b: a + b, total_counts, 0)
@@ -118,7 +118,7 @@ def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArr
         # Reshape back to the original shape
         return permuted.reshape(orig_shape)
 
-    return jax.
+    return jax.tree.map(permute_array, data)
 
 
 def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
@@ -133,7 +133,7 @@ def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], r
             return x[tuple(idx_grids)]
         return x
 
-    return jax.
+    return jax.tree.map(permute_array, data)
 
 
 TransposeResult = tuple[PyTree, tuple[int, ...], tuple[int, ...]]
@@ -215,7 +215,7 @@ def reshuffle_pytree_along_dims(
             transpose_info[path] = (transpose_order, original_shape)
         return x
 
-    jax.
+    jax.tree.map_with_path(prepare_for_shuffle, data)
 
     # Create a transposed pytree
     def get_transposed(path: PathType, x: PyTree) -> PyTree:
@@ -223,7 +223,7 @@
             return transposed_data[path]
         return x
 
-    transposed_pytree = jax.
+    transposed_pytree = jax.tree.map_with_path(get_transposed, data)
 
     # Reshuffle the transposed pytree along the leading dimensions
     reshuffled_transposed = reshuffle_pytree(transposed_pytree, shape_dims, rng)
@@ -235,4 +235,4 @@
             return transpose_back(x, transpose_order, original_shape)
         return x
 
-    return jax.
+    return jax.tree.map_with_path(restore_transpose, reshuffled_transposed)
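All of these are mechanical renames onto the `jax.tree` namespace, available in recent JAX releases (0.4.25+) as the supported spelling of the older `jax.tree_util.tree_*` helpers and of the deprecated `jax.tree_map` aliases. Both spellings side by side, on the building blocks of `compute_nan_ratio`:

    import jax
    import jax.numpy as jnp

    tree = {"w": jnp.array([1.0, jnp.nan]), "b": jnp.array([0.0])}

    # Older spelling, still available:
    nan_counts = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.isnan(x)), tree)
    # Spelling adopted in 0.1.10:
    nan_counts = jax.tree.map(lambda x: jnp.sum(jnp.isnan(x)), tree)

    total_nans = jax.tree_util.tree_reduce(lambda a, b: a + b, nan_counts, 0)
    print(total_nans)  # 1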
{xax-0.1.8.dist-info → xax-0.1.10.dist-info}/RECORD
CHANGED

@@ -1,10 +1,10 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=bvOBMlEVA46I7ILGfk5AbpwpcdTAjw-4vWI7ci7L7-g,13392
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
 xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
-xax/core/state.py,sha256=
+xax/core/state.py,sha256=WwW0qDm-be9MMOT-bGWEFvaWF4iq2FP9xRSn1zq_4A8,2507
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
@@ -14,9 +14,9 @@ xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
 xax/nn/norm.py,sha256=cDmYf5CtyzmuCiWdSP5nr8nZKQOmaZueDQXMPnThg6c,548
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xax/task/base.py,sha256=
+xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
 xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
-xax/task/script.py,sha256=
+xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
 xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
 xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
@@ -39,7 +39,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=jAzc9RD25DbhekvItzsRQQrK9aEwtA_sXy0m2Hfkuxo,24594
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
 xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -48,7 +48,7 @@ xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
 xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
 xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
-xax/utils/pytree.py,sha256=
+xax/utils/pytree.py,sha256=VFWhT0MQ99KjQyEYM6NFbqYq4_hOZwB23uhowMB4U34,8754
 xax/utils/tensorboard.py,sha256=21czW8WC2SAmwEhz6RLJc_q5HFvNKM4iR1ZycSO5qPE,17058
 xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
 xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -56,8 +56,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.1.8.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
-xax-0.1.8.dist-info/METADATA,sha256=
-xax-0.1.8.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-xax-0.1.8.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
-xax-0.1.8.dist-info/RECORD,,
+xax-0.1.10.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.1.10.dist-info/METADATA,sha256=kJ1lxZ6cWrtJ5R-adTorzEE_1l0VRJ67xfuBjYXG9Vo,1878
+xax-0.1.10.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.1.10.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.1.10.dist-info/RECORD,,

{xax-0.1.8.dist-info → xax-0.1.10.dist-info}/WHEEL
File without changes

{xax-0.1.8.dist-info → xax-0.1.10.dist-info}/licenses/LICENSE
File without changes

{xax-0.1.8.dist-info → xax-0.1.10.dist-info}/top_level.txt
File without changes