xax 0.2.13__tar.gz → 0.2.15__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.2.13/xax.egg-info → xax-0.2.15}/PKG-INFO +1 -1
- {xax-0.2.13 → xax-0.2.15}/xax/__init__.py +15 -5
- xax-0.2.15/xax/nn/metrics.py +92 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/mixins/checkpointing.py +81 -54
- {xax-0.2.13 → xax-0.2.15}/xax/task/mixins/train.py +91 -56
- {xax-0.2.13 → xax-0.2.15}/xax/utils/pytree.py +10 -0
- {xax-0.2.13 → xax-0.2.15/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.2.13 → xax-0.2.15}/xax.egg-info/SOURCES.txt +1 -1
- xax-0.2.13/xax/nn/norm.py +0 -24
- {xax-0.2.13 → xax-0.2.15}/LICENSE +0 -0
- {xax-0.2.13 → xax-0.2.15}/MANIFEST.in +0 -0
- {xax-0.2.13 → xax-0.2.15}/README.md +0 -0
- {xax-0.2.13 → xax-0.2.15}/pyproject.toml +0 -0
- {xax-0.2.13 → xax-0.2.15}/setup.cfg +0 -0
- {xax-0.2.13 → xax-0.2.15}/setup.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/core/__init__.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/core/conf.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/core/state.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/nn/__init__.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/nn/embeddings.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/nn/equinox.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/nn/export.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/nn/functions.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/nn/geom.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/nn/losses.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/nn/parallel.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/nn/ssm.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/py.typed +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/requirements-dev.txt +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/requirements.txt +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/__init__.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/base.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/launchers/__init__.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/launchers/base.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/launchers/cli.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/launchers/single_process.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/logger.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/loggers/__init__.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/loggers/callback.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/loggers/json.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/loggers/state.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/loggers/stdout.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/mixins/__init__.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/mixins/compile.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/mixins/logger.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/mixins/process.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/mixins/runnable.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/script.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/task/task.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/__init__.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/data/__init__.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/data/collate.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/debugging.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/experiments.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/jax.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/jaxpr.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/logging.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/numpy.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/profile.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/tensorboard.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/text.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/types/__init__.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax.egg-info/requires.txt +0 -0
- {xax-0.2.13 → xax-0.2.15}/xax.egg-info/top_level.txt +0 -0
{xax-0.2.13 → xax-0.2.15}/xax/__init__.py

@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.2.13"
+__version__ = "0.2.15"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [

@@ -51,6 +51,7 @@ __all__ = [
     "rotation_matrix_to_rotation6d",
     "cross_entropy",
     "cast_norm_type",
+    "dynamic_time_warping",
     "get_norm",
     "is_master",
     "BaseSSMBlock",

@@ -136,6 +137,7 @@ __all__ = [
     "reshuffle_pytree_independently",
     "slice_array",
     "slice_pytree",
+    "tuple_insert",
     "update_pytree",
     "TextBlock",
     "camelcase_to_snakecase",

@@ -229,8 +231,9 @@ NAME_MAP: dict[str, str] = {
     "rotation6d_to_rotation_matrix": "nn.geom",
     "rotation_matrix_to_rotation6d": "nn.geom",
     "cross_entropy": "nn.losses",
-    "cast_norm_type": "nn.norm",
-    "get_norm": "nn.norm",
+    "cast_norm_type": "nn.metrics",
+    "dynamic_time_warping": "nn.metrics",
+    "get_norm": "nn.metrics",
     "is_master": "nn.parallel",
     "BaseSSMBlock": "nn.ssm",
     "DiagSSMBlock": "nn.ssm",

@@ -315,6 +318,7 @@ NAME_MAP: dict[str, str] = {
     "reshuffle_pytree_independently": "utils.pytree",
     "slice_array": "utils.pytree",
     "slice_pytree": "utils.pytree",
+    "tuple_insert": "utils.pytree",
     "update_pytree": "utils.pytree",
     "TextBlock": "utils.text",
     "camelcase_to_snakecase": "utils.text",

@@ -345,7 +349,7 @@ NAME_MAP.update(
         "LOG_ERROR_SUMMARY": "utils.logging",
         "LOG_PING": "utils.logging",
         "LOG_STATUS": "utils.logging",
-        "NormType": "nn.norm",
+        "NormType": "nn.metrics",
         "Output": "task.mixins.output",
         "Phase": "core.state",
         "RawConfigType": "task.base",

@@ -410,7 +414,12 @@ if IMPORT_ALL or TYPE_CHECKING:
         rotation_matrix_to_rotation6d,
     )
     from xax.nn.losses import cross_entropy
-    from xax.nn.norm import NormType, cast_norm_type, get_norm
+    from xax.nn.metrics import (
+        NormType,
+        cast_norm_type,
+        dynamic_time_warping,
+        get_norm,
+    )
     from xax.nn.parallel import is_master
     from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
     from xax.task.base import RawConfigType

@@ -495,6 +504,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         reshuffle_pytree_independently,
         slice_array,
         slice_pytree,
+        tuple_insert,
         update_pytree,
     )
     from xax.utils.text import (
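Both new names are registered in __all__ and NAME_MAP above, so, assuming the package's lazy-import hook keeps resolving NAME_MAP entries as in previous releases, they are reachable from the top-level namespace. A minimal sketch, not taken from the package's documentation:

import jax.numpy as jnp
import xax

# "dynamic_time_warping" resolves to xax.nn.metrics, "tuple_insert" to xax.utils.pytree.
distances = jnp.abs(jnp.arange(3.0)[:, None] - jnp.arange(4.0)[None, :])  # (N, M) with N <= M
cost, path = xax.dynamic_time_warping(distances, return_path=True)
print(xax.tuple_insert((0, 1, 2), 1, 9))  # (0, 9, 2)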
xax-0.2.15/xax/nn/metrics.py

@@ -0,0 +1,92 @@
+"""Norm and metric utilities."""
+
+from typing import Literal, cast, get_args, overload
+
+import chex
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array
+
+from xax.utils.jax import jit as xax_jit
+
+NormType = Literal["l1", "l2"]
+
+
+def cast_norm_type(norm: str) -> NormType:
+    if norm not in get_args(NormType):
+        raise ValueError(f"Invalid norm: {norm}")
+    return cast(NormType, norm)
+
+
+def get_norm(x: Array, norm: NormType) -> Array:
+    match norm:
+        case "l1":
+            return jnp.abs(x)
+        case "l2":
+            return jnp.square(x)
+        case _:
+            raise ValueError(f"Invalid norm: {norm}")
+
+
+@overload
+def dynamic_time_warping(distance_matrix_nm: Array) -> Array: ...
+
+
+@overload
+def dynamic_time_warping(distance_matrix_nm: Array, return_path: Literal[True]) -> tuple[Array, Array]: ...
+
+
+@xax_jit(static_argnames=["return_path"])
+def dynamic_time_warping(distance_matrix_nm: Array, return_path: bool = False) -> Array | tuple[Array, Array]:
+    """Dynamic Time Warping.
+
+    Args:
+        distance_matrix_nm: A matrix of pairwise distances between two
+            sequences, with shape (N, M), with the condition that N <= M.
+        return_path: If set, return the minimum path, otherwise just return
+            the cost. The latter is preferred if using this function as a
+            distance metric since it avoids the backwards scan on backpointers.
+
+    Returns:
+        The cost of the minimum path from the top-left corner of the distance
+        matrix to the bottom-right corner, along with the indices of that
+        minimum path.
+    """
+    chex.assert_shape(distance_matrix_nm, (None, None))
+    n, m = distance_matrix_nm.shape
+
+    assert n <= m, f"Invalid dynamic time warping distance matrix shape: ({n}, {m})"
+
+    # Masks values which cannot be reached.
+    row_idx = jnp.arange(n)[:, None]
+    col_idx = jnp.arange(m)[None, :]
+    mask = row_idx > col_idx
+    distance_matrix_nm = jnp.where(mask, jnp.inf, distance_matrix_nm)
+
+    # Pre-pads with inf
+    distance_matrix_nm = jnp.pad(distance_matrix_nm, ((1, 0), (0, 0)), mode="constant", constant_values=jnp.inf)
+    indices = jnp.arange(n)
+
+    # Scan over remaining rows to fill cost matrix
+    def scan_fn(prev_cost: Array, cur_distances: Array) -> tuple[Array, Array]:
+        same_trans = prev_cost
+        prev_trans = jnp.pad(prev_cost[:-1], ((1, 0),), mode="constant", constant_values=jnp.inf)
+        nc = jnp.minimum(prev_trans, same_trans) + cur_distances[1:]
+        return nc, jnp.where(prev_trans < same_trans, indices - 1, indices) if return_path else nc
+
+    init_cost = distance_matrix_nm[1:, 0]
+    final_cost, back_pointers = jax.lax.scan(scan_fn, init_cost, distance_matrix_nm[:, 1:].T)
+
+    if not return_path:
+        return final_cost
+
+    # Scan the back pointers backwards to get the minimum path.
+    def scan_back_fn(carry: Array, back_pointer: Array) -> tuple[Array, Array]:
+        prev_idx = back_pointer[carry]
+        return prev_idx, carry
+
+    final_index = jnp.array(n - 1)
+    _, min_path = jax.lax.scan(scan_back_fn, final_index, back_pointers, reverse=True)
+    min_path = jnp.pad(min_path, ((1, 0)), mode="constant", constant_values=0)
+
+    return final_cost[-1], min_path
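A minimal usage sketch for the new module (not from the package's documentation; the sequences and the choice of L1 distance are illustrative). The caller precomputes the pairwise distance matrix; with return_path=True the function returns the bottom-right alignment cost plus one row index per column, while without it the cumulative costs of the final column are returned and the last entry is the full alignment cost:

import jax.numpy as jnp

from xax.nn.metrics import dynamic_time_warping, get_norm

query = jnp.array([0.0, 1.0, 2.0])                 # length N = 3
reference = jnp.array([0.0, 0.5, 1.0, 1.5, 2.0])   # length M = 5, N <= M

# Pairwise L1 distances, shape (N, M).
distance_matrix = get_norm(query[:, None] - reference[None, :], "l1")

costs = dynamic_time_warping(distance_matrix)                          # cost vector for the last column
cost, path = dynamic_time_warping(distance_matrix, return_path=True)   # scalar cost and warping path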
{xax-0.2.13 → xax-0.2.15}/xax/task/mixins/checkpointing.py

@@ -6,7 +6,7 @@ import logging
 import tarfile
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Generic, Literal, TypeVar, cast, overload
+from typing import Generic, Literal, Sequence, TypeVar, cast, overload
 
 import equinox as eqx
 import jax

@@ -57,10 +57,10 @@ def load_ckpt(
     path: Path,
     *,
     part: Literal["all"],
-
-
-
-) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+    model_templates: Sequence[PyTree],
+    optimizer_templates: Sequence[optax.GradientTransformation],
+    opt_state_templates: Sequence[optax.OptState],
+) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State, DictConfig]: ...
 
 
 @overload

@@ -68,20 +68,35 @@ def load_ckpt(
     path: Path,
     *,
     part: Literal["model_state_config"],
-
-) -> tuple[PyTree, State, DictConfig]: ...
+    model_templates: Sequence[PyTree],
+) -> tuple[list[PyTree], State, DictConfig]: ...
 
 
 @overload
-def load_ckpt(
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["model"],
+    model_templates: Sequence[PyTree],
+) -> list[PyTree]: ...
 
 
 @overload
-def load_ckpt(
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["opt"],
+    optimizer_templates: Sequence[optax.GradientTransformation],
+) -> list[optax.GradientTransformation]: ...
 
 
 @overload
-def load_ckpt(
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["opt_state"],
+    opt_state_templates: Sequence[optax.OptState],
+) -> list[optax.OptState]: ...
 
 
 @overload

@@ -96,40 +111,49 @@ def load_ckpt(
     path: str | Path,
     *,
     part: CheckpointPart = "model",
-
-
-
+    model_templates: Sequence[PyTree] | None = None,
+    optimizer_templates: Sequence[optax.GradientTransformation] | None = None,
+    opt_state_templates: Sequence[optax.OptState] | None = None,
 ) -> (
-    tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
-    | tuple[PyTree, State, DictConfig]
-    | PyTree
-    | optax.GradientTransformation
-    | optax.OptState
+    tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State, DictConfig]
+    | tuple[list[PyTree], State, DictConfig]
+    | list[PyTree]
+    | list[optax.GradientTransformation]
+    | list[optax.OptState]
     | State
     | DictConfig
 ):
     with tarfile.open(path, "r:gz") as tar:
 
-        def get_model() -> PyTree:
-            if
+        def get_model() -> list[PyTree]:
+            if model_templates is None:
                 raise ValueError("model_template must be provided to load model weights")
-
-
-
-
-
-
+            models: list[PyTree] = []
+            for i, model_template in enumerate(model_templates):
+                if (model := tar.extractfile(f"model_{i}")) is None:
+                    raise ValueError(f"Checkpoint does not contain a model file: {path}")
+                models.append(eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template))
+            return models
+
+        def get_opt() -> list[optax.GradientTransformation]:
+            if optimizer_templates is None:
                 raise ValueError("optimizer_template must be provided to load optimizer")
-
-
-
-
-
-
+            opts: list[optax.GradientTransformation] = []
+            for i, optimizer_template in enumerate(optimizer_templates):
+                if (opt := tar.extractfile(f"optimizer_{i}")) is None:
+                    raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
+                opts.append(eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template))
+            return opts
+
+        def get_opt_state() -> list[optax.OptState]:
+            if opt_state_templates is None:
                 raise ValueError("opt_state_template must be provided to load optimizer state")
-
-
-
+            opt_states: list[optax.OptState] = []
+            for i, opt_state_template in enumerate(opt_state_templates):
+                if (opt_state := tar.extractfile(f"opt_state_{i}")) is None:
+                    raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
+                opt_states.append(eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template))
+            return opt_states
 
         def get_state() -> State:
             if (state := tar.extractfile("state")) is None:

@@ -192,20 +216,20 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
 
     def save_checkpoint(
         self,
-
-
-
+        models: Sequence[PyTree] | None = None,
+        optimizers: Sequence[optax.GradientTransformation] | None = None,
+        opt_states: Sequence[optax.OptState] | None = None,
         aux_data: PyTree | None = None,
         state: State | None = None,
     ) -> Path:
         """Save a checkpoint.
 
         Args:
-
-
-
+            models: The models to save
+            optimizers: The optimizers to save
+            opt_states: The optimizer states to save
             aux_data: Additional data to save
-
+            state: The current training state
 
         Returns:
             Path to the saved checkpoint

@@ -235,22 +259,25 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
                 tar.addfile(tarinfo, buf)
 
             # Save model using Equinox
-            if
-
-
-
+            if models is not None:
+                for i, model in enumerate(models):
+                    with io.BytesIO() as buf:
+                        eqx.tree_serialise_leaves(buf, model)
+                        add_file(f"model_{i}", buf)
 
             # Save optimizer using Equinox
-            if
-
-
-
+            if optimizers is not None:
+                for i, optimizer in enumerate(optimizers):
+                    with io.BytesIO() as buf:
+                        eqx.tree_serialise_leaves(buf, optimizer)
+                        add_file(f"optimizer_{i}", buf)
 
             # Save optimizer state using Equinox
-            if
-
-
-
+            if opt_states is not None:
+                for i, opt_state in enumerate(opt_states):
+                    with io.BytesIO() as buf:
+                        eqx.tree_serialise_leaves(buf, opt_state)
+                        add_file(f"opt_state_{i}", buf)
 
             # Save aux data using Equinox.
             if aux_data is not None:
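Since each part is now written and read per index (model_0, model_1, ..., optimizer_0, ...), callers pass one template per entry and get a list back. A minimal loading sketch, not from the package's docs: the eqx.nn.Linear modules and the checkpoint path are placeholders, and the templates mirror how the train mixin builds them with eqx.filter_eval_shape:

from pathlib import Path

import equinox as eqx
import jax

from xax.task.mixins.checkpointing import load_ckpt

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))

# Shape/dtype skeletons with the same structure as the checkpointed models.
model_templates = [
    eqx.filter_eval_shape(eqx.nn.Linear, 4, 32, key=key_a),
    eqx.filter_eval_shape(eqx.nn.Linear, 32, 2, key=key_b),
]

ckpt_path = Path("run_dir/checkpoints/ckpt.bin")  # placeholder path
first_model, second_model = load_ckpt(ckpt_path, part="model", model_templates=model_templates)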
{xax-0.2.13 → xax-0.2.15}/xax/task/mixins/train.py

@@ -310,23 +310,46 @@ class TrainMixin(
         self.write_logs(state)
 
     @abstractmethod
-    def get_model(self, key: PRNGKeyArray) -> PyTree:
+    def get_model(self, key: PRNGKeyArray) -> PyTree | Sequence[PyTree]:
         """Returns the Equinox model to train.
 
         Returns:
             The model to train.
         """
 
+    def _get_models(self, key: PRNGKeyArray) -> list[PyTree]:
+        models = self.get_model(key)
+        if isinstance(models, Sequence):
+            models = list(models)
+        elif isinstance(models, eqx.Module):
+            models = [models]
+        else:
+            logger.warning("Model is not a sequence or an eqx.Module, wrapping it in a list anyway")
+            models = [models]
+        return models
+
     @abstractmethod
-    def get_optimizer(self) -> optax.GradientTransformation:
+    def get_optimizer(self) -> optax.GradientTransformation | Sequence[optax.GradientTransformation]:
         """Gets the optimizer for the model.
 
         Returns:
             The optimizer to use to train the model.
         """
 
-    def
-
+    def _get_optimizers(self) -> list[optax.GradientTransformation]:
+        optimizers = self.get_optimizer()
+        if isinstance(optimizers, optax.GradientTransformation):
+            optimizers = [optimizers]
+        elif isinstance(optimizers, Sequence):
+            optimizers = list(optimizers)
+        return optimizers
+
+    def get_initial_opt_state(
+        self,
+        models: list[PyTree],
+        optimizers: list[optax.GradientTransformation],
+    ) -> list[optax.OptState]:
+        return [opt.init(eqx.filter(model, eqx.is_array)) for model, opt in zip(models, optimizers, strict=True)]
 
     @overload
     def load_initial_state(

@@ -340,13 +363,16 @@ class TrainMixin(
         self,
         key: PRNGKeyArray,
         load_optimizer: Literal[True],
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State]: ...
+    ) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State]: ...
 
     def load_initial_state(
         self,
         key: PRNGKeyArray,
         load_optimizer: bool = False,
-    ) ->
+    ) -> (
+        tuple[list[PyTree], State]
+        | tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State]
+    ):
         init_ckpt_path = self.get_init_ckpt_path()
 
         if init_ckpt_path is not None:

@@ -364,16 +390,17 @@ class TrainMixin(
                 return model, optimizer, opt_state, state
 
         logger.info("Starting a new training run")
-
+        models = self._get_models(key)
         state = State.init_state()
 
        if not load_optimizer:
-            return
+            return models, state
 
-        optimizer
-
+        # Gets the optimizer(s) for the model.
+        optimizers = self._get_optimizers()
+        opt_states = self.get_initial_opt_state(models, optimizers)
 
-        return
+        return models, optimizers, opt_states, state
 
     @overload
     def load_ckpt(

@@ -381,7 +408,7 @@ class TrainMixin(
         path: Path,
         *,
         part: Literal["all"],
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
+    ) -> tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State, Config]: ...
 
     @overload
     def load_ckpt(

@@ -389,7 +416,7 @@ class TrainMixin(
         path: Path,
         *,
         part: Literal["model_state_config"],
-    ) -> tuple[PyTree, State, Config]: ...
+    ) -> tuple[list[PyTree], State, Config]: ...
 
     @overload
     def load_ckpt(

@@ -397,7 +424,7 @@ class TrainMixin(
         path: Path,
         *,
         part: Literal["model"],
-    ) -> PyTree: ...
+    ) -> list[PyTree]: ...
 
     @overload
     def load_ckpt(

@@ -405,7 +432,7 @@ class TrainMixin(
         path: Path,
         *,
         part: Literal["opt"],
-    ) -> optax.GradientTransformation: ...
+    ) -> list[optax.GradientTransformation]: ...
 
     @overload
     def load_ckpt(

@@ -415,7 +442,7 @@ class TrainMixin(
         part: Literal["opt_state"],
         model: PyTree | None = None,
         optimizer: optax.GradientTransformation | None = None,
-    ) -> optax.OptState: ...
+    ) -> list[optax.OptState]: ...
 
     @overload
     def load_ckpt(

@@ -423,7 +450,7 @@ class TrainMixin(
         path: Path,
         *,
         part: Literal["state"],
-    ) -> State: ...
+    ) -> list[State]: ...
 
     @overload
     def load_ckpt(

@@ -431,7 +458,7 @@ class TrainMixin(
         path: Path,
         *,
         part: Literal["config"],
-    ) -> Config: ...
+    ) -> list[Config]: ...
 
     def load_ckpt(
         self,

@@ -441,11 +468,11 @@ class TrainMixin(
         model: PyTree | None = None,
         optimizer: optax.GradientTransformation | None = None,
     ) -> (
-        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
-        | tuple[PyTree, State, Config]
-        | PyTree
-        | optax.GradientTransformation
-        | optax.OptState
+        tuple[list[PyTree], list[optax.GradientTransformation], list[optax.OptState], State, Config]
+        | tuple[list[PyTree], State, Config]
+        | list[PyTree]
+        | list[optax.GradientTransformation]
+        | list[optax.OptState]
         | State
         | Config
     ):

@@ -456,28 +483,28 @@ class TrainMixin(
 
         match part:
             case "model_state_config":
-
-                model, state, config = load_ckpt(path, part="model_state_config",
+                model_specs = eqx.filter_eval_shape(self._get_models, key)
+                model, state, config = load_ckpt(path, part="model_state_config", model_templates=model_specs)
                 config = self.get_config(config, use_cli=False)
                 return model, state, config
 
             case "model":
-
-                return load_ckpt(path, part="model",
+                model_specs = eqx.filter_eval_shape(self._get_models, key)
+                return load_ckpt(path, part="model", model_templates=model_specs)
 
             case "opt":
-
-                return load_ckpt(path, part="opt",
+                optimizer_specs = eqx.filter_eval_shape(self._get_optimizers)
+                return load_ckpt(path, part="opt", optimizer_templates=optimizer_specs)
 
             case "opt_state":
                 if model is None:
-
-                    model = load_ckpt(path, part="model",
+                    model_specs = eqx.filter_eval_shape(self._get_models, key)
+                    model = load_ckpt(path, part="model", model_templates=model_specs)
                 if optimizer is None:
-
-                    optimizer = load_ckpt(path, part="opt",
-
-                return load_ckpt(path, part="opt_state",
+                    optimizer_specs = eqx.filter_eval_shape(self._get_optimizers)
+                    optimizer = load_ckpt(path, part="opt", optimizer_templates=optimizer_specs)
+                opt_state_specs = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+                return load_ckpt(path, part="opt_state", opt_state_templates=opt_state_specs)
 
             case "state":
                 return load_ckpt(path, part="state")

@@ -486,12 +513,12 @@ class TrainMixin(
                 return self.get_config(load_ckpt(path, part="config"), use_cli=False)
 
             case "all":
-
-                model = load_ckpt(path, part="model",
-
-                optimizer = load_ckpt(path, part="opt",
-
-                opt_state = load_ckpt(path, part="opt_state",
+                model_specs = eqx.filter_eval_shape(self._get_models, key)
+                model = load_ckpt(path, part="model", model_templates=model_specs)
+                optimizer_specs = eqx.filter_eval_shape(self._get_optimizers)
+                optimizer = load_ckpt(path, part="opt", optimizer_templates=optimizer_specs)
+                opt_state_specs = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+                opt_state = load_ckpt(path, part="opt_state", opt_state_templates=opt_state_specs)
                 state = load_ckpt(path, part="state")
                 config = self.get_config(load_ckpt(path, part="config"), use_cli=False)
                 return model, optimizer, opt_state, state, config

@@ -718,14 +745,22 @@ class TrainMixin(
 
     def train_loop(
         self,
-
-
-
+        models: Sequence[PyTree],
+        optimizers: Sequence[optax.GradientTransformation],
+        opt_states: Sequence[optax.OptState],
         train_pf: Iterator[Batch],
         valid_pf: Iterator[Batch],
         state: State,
     ) -> None:
-
+        if len(models) != 1 or len(optimizers) != 1 or len(opt_states) != 1:
+            raise ValueError(
+                "Vanilla training expects a single model, optimizer and optimizer state. "
+                f"Found {len(models)} models, {len(optimizers)} optimizers and {len(opt_states)} optimizer states."
+            )
+
+        model_arr, model_static = eqx.partition(models[0], self.model_partition_fn)
+        optimizer = optimizers[0]
+        opt_state = opt_states[0]
 
         while not self.is_training_over(state):
             valid_step = self.valid_step_timer(state)

@@ -773,11 +808,11 @@ class TrainMixin(
 
             if self.should_checkpoint(state):
                 model = eqx.combine(model_arr, model_static)
-                self.save_checkpoint(
+                self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
 
         # After finishing training, save the final checkpoint.
         model = eqx.combine(model_arr, model_static)
-        self.save_checkpoint(
+        self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
 
     @contextlib.contextmanager
     def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:

@@ -841,14 +876,14 @@ class TrainMixin(
             Thread(target=self.log_state, daemon=True).start()
 
             key, model_key = jax.random.split(key)
-
-            logger.info("Model size: %s", f"{get_pytree_param_count(
-            logger.info("Optimizer size: %s", f"{get_pytree_param_count(
+            models, optimizers, opt_states, state = self.load_initial_state(model_key, load_optimizer=True)
+            logger.info("Model size: %s", f"{get_pytree_param_count(models):,}")
+            logger.info("Optimizer size: %s", f"{get_pytree_param_count(optimizers):,}")
 
             state = self.on_training_start(state)
 
             def on_exit() -> None:
-                self.save_checkpoint(
+                self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
 
             # Handle user-defined interrupts during the training loop.
             self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)

@@ -857,9 +892,9 @@ class TrainMixin(
             with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
                 try:
                     self.train_loop(
-
-
-
+                        models=models,
+                        optimizers=optimizers,
+                        opt_states=opt_states,
                        train_pf=train_pf,
                        valid_pf=valid_pf,
                        state=state,

@@ -869,7 +904,7 @@ class TrainMixin(
                     if is_master():
                         num_steps, num_samples = int(state.num_steps), int(state.num_samples)
                         show_info(f"Finished training after {num_steps} steps, {num_samples} samples", important=True)
-                        self.save_checkpoint(
+                        self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
 
                 except (KeyboardInterrupt, bdb.BdbQuit):
                     if is_master():

@@ -879,7 +914,7 @@ class TrainMixin(
                     exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
                     sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
                     sys.stdout.flush()
-                    self.save_checkpoint(
+                    self.save_checkpoint(models=models, optimizers=optimizers, opt_states=opt_states, state=state)
 
                 finally:
                     state = self.on_training_end(state)
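The new get_initial_opt_state pairs each model with its optimizer and builds one optimizer state per pair. It reduces to the following standalone sketch (the toy eqx.nn.Linear modules and learning rates are illustrative only); note that the default train_loop above still requires exactly one entry in each list:

import equinox as eqx
import jax
import optax

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))

# Two models with matching optimizers, in the same order.
models = [eqx.nn.Linear(4, 8, key=key_a), eqx.nn.Linear(8, 2, key=key_b)]
optimizers = [optax.adam(1e-3), optax.sgd(1e-2)]

# One optimizer state per (model, optimizer) pair, initialized on the array leaves only.
opt_states = [opt.init(eqx.filter(m, eqx.is_array)) for m, opt in zip(models, optimizers, strict=True)]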
{xax-0.2.13 → xax-0.2.15}/xax/utils/pytree.py

@@ -1,5 +1,7 @@
 """Utils for accessing, modifying, and otherwise manipulating pytrees."""
 
+from typing import TypeVar
+
 import chex
 import equinox as eqx
 import jax

@@ -7,6 +9,8 @@ import jax.numpy as jnp
 from jax import Array
 from jaxtyping import PRNGKeyArray, PyTree
 
+T = TypeVar("T")
+
 
 def slice_array(x: Array, start: Array, slice_length: int) -> Array:
     """Get a slice of an array along the first dimension.

@@ -243,3 +247,9 @@ def get_pytree_param_count(pytree: PyTree) -> int:
     """Calculates the total number of parameters in a PyTree."""
     leaves, _ = jax.tree.flatten(pytree)
     return sum(x.size for x in leaves if isinstance(x, jnp.ndarray) and eqx.is_inexact_array(x))
+
+
+def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
+    mut = list(t)
+    mut[index] = value
+    return tuple(mut)
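Note that, as written, tuple_insert overwrites the element at index rather than lengthening the tuple (it assigns into the list copy instead of calling list.insert). A quick sketch of that behavior:

from xax.utils.pytree import tuple_insert

fields = ("a", "b", "c")
assert tuple_insert(fields, 1, "z") == ("a", "z", "c")  # index 1 replaced, length unchanged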
xax-0.2.13/xax/nn/norm.py
DELETED

@@ -1,24 +0,0 @@
-"""Normalization utilities."""
-
-from typing import Literal, cast, get_args
-
-import jax.numpy as jnp
-from jaxtyping import Array
-
-NormType = Literal["l1", "l2"]
-
-
-def cast_norm_type(norm: str) -> NormType:
-    if norm not in get_args(NormType):
-        raise ValueError(f"Invalid norm: {norm}")
-    return cast(NormType, norm)
-
-
-def get_norm(x: Array, norm: NormType) -> Array:
-    match norm:
-        case "l1":
-            return jnp.abs(x)
-        case "l2":
-            return jnp.square(x)
-        case _:
-            raise ValueError(f"Invalid norm: {norm}")
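For downstream code that imported from the removed module, the change is a path rename; the functions themselves are unchanged, and the top-level names registered in __all__ and NAME_MAP are unaffected:

# Before (xax 0.2.13)
from xax.nn.norm import NormType, cast_norm_type, get_norm

# After (xax 0.2.15)
from xax.nn.metrics import NormType, cast_norm_type, get_norm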