xax 0.0.3__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +122 -8
- xax/core/conf.py +9 -33
- xax/core/state.py +13 -23
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +8 -4
- xax/requirements-dev.txt +9 -1
- xax/requirements.txt +17 -10
- xax/task/base.py +2 -6
- xax/task/logger.py +419 -412
- xax/task/loggers/callback.py +44 -0
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +16 -33
- xax/task/mixins/__init__.py +3 -1
- xax/task/mixins/artifacts.py +19 -9
- xax/task/mixins/checkpointing.py +221 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +26 -15
- xax/task/mixins/data_loader.py +27 -19
- xax/task/mixins/gpu_stats.py +22 -8
- xax/task/mixins/logger.py +5 -251
- xax/task/mixins/process.py +8 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +236 -145
- xax/task/script.py +1 -1
- xax/task/task.py +13 -5
- xax/utils/data/collate.py +6 -6
- xax/utils/experiments.py +45 -1
- xax/utils/logging.py +29 -0
- xax/utils/tensorboard.py +89 -21
- xax-0.0.6.dist-info/METADATA +50 -0
- xax-0.0.6.dist-info/RECORD +52 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.3.dist-info/METADATA +0 -39
- xax-0.0.3.dist-info/RECORD +0 -49
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/LICENSE +0 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/top_level.txt +0 -0
xax/nn/embeddings.py
ADDED
@@ -0,0 +1,355 @@
|
|
1
|
+
"""Defines embedding layers."""
|
2
|
+
|
3
|
+
import math
|
4
|
+
from typing import Literal, cast, get_args, overload
|
5
|
+
|
6
|
+
import equinox as eqx
|
7
|
+
import jax
|
8
|
+
import jax.numpy as jnp
|
9
|
+
import jax.random as jrandom
|
10
|
+
from jaxtyping import Array, DTypeLike, PRNGKeyArray
|
11
|
+
|
12
|
+
EmbeddingKind = Literal["identity", "learned", "sinusoidal", "rotary"]


def cast_embedding_kind(k: str) -> EmbeddingKind:
    """Validates a string and narrows it to an ``EmbeddingKind``.

    Args:
        k: The candidate embedding kind.

    Returns:
        ``k`` itself, typed as an ``EmbeddingKind``.

    Raises:
        ValueError: If ``k`` is not one of the supported embedding kinds.
    """
    args = get_args(EmbeddingKind)
    # Raise explicitly instead of using `assert`, so validation survives
    # `python -O`; ValueError also matches `get_positional_embeddings`.
    if k not in args:
        raise ValueError(f"Invalid embedding kind: '{k}'. Valid options are {args}")
    return cast(EmbeddingKind, k)
|
19
|
+
|
20
|
+
|
21
|
+
class IdentityPositionalEmbeddings(eqx.Module):
    """Positional embeddings module that returns its input untouched.

    ``offset`` and ``times_t`` are accepted only so that the call signature
    matches the other positional embedding modules; both are ignored.
    """

    def __call__(self, x: Array, offset: int = 0, times_t: Array | None = None) -> Array:
        del offset, times_t  # Unused; kept for interface compatibility.
        return x
|
24
|
+
|
25
|
+
|
26
|
+
class LearnedPositionalEmbeddings(eqx.Module):
    """Defines a learned embeddings module.

    Parameters:
        max_tsz: The maximum sequence length.
        embed_dim: The embedding dimension.
        learnable: Whether the embeddings are learnable. When ``False``,
            ``jax.lax.stop_gradient`` is applied to the looked-up embeddings.
    """

    max_tsz: int = eqx.field(static=True)
    embed_dim: int = eqx.field(static=True)
    learnable: bool = eqx.field(static=True)
    embeddings_tc: Array

    def __init__(
        self,
        max_tsz: int,
        embed_dim: int,
        learnable: bool = True,
        *,
        key: PRNGKeyArray,
    ) -> None:
        super().__init__()

        self.max_tsz = max_tsz
        self.embed_dim = embed_dim
        self.learnable = learnable

        # Embedding table of shape (max_tsz, embed_dim), drawn from a standard normal.
        self.embeddings_tc = jrandom.normal(key, (max_tsz, embed_dim))

    def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
        """Adds the learned positional embeddings to the input.

        Args:
            x_tc: The input; its second-to-last axis is the time axis and its
                last axis must have size ``embed_dim``.
            offset: The position of the first time step of ``x_tc``; ignored
                when ``times_t`` is provided.
            times_t: Optional explicit positions used to index the table.

        Returns:
            ``x_tc`` plus the embeddings for the selected positions.
        """
        if times_t is None:
            emb_tc = self.embeddings_tc[offset : offset + x_tc.shape[-2]]
        else:
            emb_tc = self.embeddings_tc[times_t]
        if not self.learnable:
            emb_tc = jax.lax.stop_gradient(emb_tc)
        # Add to the input (like the identity, sinusoidal and rotary modules
        # transform it) instead of returning the bare embedding slice, so all
        # modules produced by `get_positional_embeddings` are interchangeable.
        return x_tc + emb_tc
|
65
|
+
|
66
|
+
|
67
|
+
class SinusoidalEmbeddings(eqx.Module):
    """Defines a sinusoidal embeddings module.

    Parameters:
        embed_dim: The embedding dimension.
        max_tsz: The maximum sequence length.
        learnable: Whether the embeddings are learnable. If set, the embedding
            table is computed once in ``__init__`` (which then requires
            ``max_tsz`` and ``embed_dim``) and stored in ``embeddings_tc``;
            otherwise it is recomputed on every call.
        base: The base for the sinusoidal embeddings.
    """

    base: int = eqx.field(static=True)
    max_tsz: int | None = eqx.field(static=True)
    embed_dim: int | None = eqx.field(static=True)
    embeddings_tc: Array | None

    def __init__(
        self,
        embed_dim: int | None = None,
        max_tsz: int | None = None,
        learnable: bool = True,
        base: int = 10_000,
    ) -> None:
        super().__init__()

        self.max_tsz = max_tsz
        self.embed_dim = embed_dim
        self.base = base

        self.embeddings_tc: Array | None = None
        if learnable:
            assert max_tsz is not None, "Learnable parameters require `max_tsz` to be set"
            assert embed_dim is not None, "Learnable parameters require `embed_dim` to be set"
            self.embeddings_tc = self.get_embeddings(max_tsz, embed_dim)

    def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
        """Adds sinusoidal positional embeddings to ``x_tc``.

        Args:
            x_tc: The input, of shape ``(tsz, dims)``.
            offset: The position of the first row of ``x_tc``; ignored when
                ``times_t`` is provided.
            times_t: Optional explicit integer positions, one per row.

        Returns:
            ``x_tc`` plus the embeddings for the selected positions.
        """
        tsz, dims = x_tc.shape

        # Without a stored table (the non-learnable case), build one just
        # large enough to cover every position that will be looked up.
        if self.embeddings_tc is None:
            if times_t is None:
                embeddings_tc = self.get_embeddings(offset + tsz, dims, x_tc.dtype)
            else:
                # NOTE: `.item()` needs a concrete value, so this branch
                # cannot be traced under `jax.jit`.
                embeddings_tc = self.get_embeddings(times_t.max().item() + 1, dims, x_tc.dtype)
        else:
            embeddings_tc = self.embeddings_tc

        # Get only the embeddings for the specified time steps.
        if times_t is None:
            embeddings_tc = embeddings_tc[offset : offset + tsz]
        else:
            embeddings_tc = embeddings_tc[times_t]

        return x_tc + embeddings_tc

    def get_embeddings(
        self,
        tsz: int,
        embed_dim: int,
        dtype: DTypeLike | None = None,
    ) -> Array:
        """Computes the sinusoidal embedding table.

        Args:
            tsz: Number of positions (rows) to compute, starting at position 0.
            embed_dim: Number of embedding channels (columns).
            dtype: Output dtype; if ``None`` the table is left in the compute
                dtype (float32).

        Returns:
            Array of shape ``(tsz, embed_dim)``: the first ``ceil(embed_dim / 2)``
            columns hold sines and the remaining columns hold cosines.
        """
        # Compute in at least float32 regardless of the requested dtype: in
        # low-precision dtypes such as bfloat16, integer positions above 256
        # are not exactly representable, which would map distinct positions
        # onto the same embedding. Cast only at the end.
        compute_dtype = jnp.float32 if dtype is None else jnp.promote_types(dtype, jnp.float32)
        positions_t = jnp.arange(tsz, dtype=compute_dtype)
        dim_d = jnp.arange(embed_dim, dtype=compute_dtype)
        dim_d = self.base ** (2 * (dim_d // 2) / embed_dim)
        embeddings_td = positions_t[:, None] / dim_d[None, :]
        embeddings_td = jnp.concatenate(
            [jnp.sin(embeddings_td[:, 0::2]), jnp.cos(embeddings_td[:, 1::2])],
            axis=-1,
        )
        return embeddings_td if dtype is None else embeddings_td.astype(dtype)
|
136
|
+
|
137
|
+
|
138
|
+
def get_rotary_embeddings(
    tsz: int,
    embed_dim: int,
    dtype: jnp.dtype,
    offset: int = 0,
    base: int = 10_000,
) -> Array:
    """Builds the cosine/sine tables used by rotary embeddings.

    Only the first half of the channels is rotated, so each table spans
    ``embed_dim // 2`` channels: ``embed_dim // 4`` frequencies, repeated
    twice along the channel axis.

    Args:
        tsz: Number of positions (rows) to generate.
        embed_dim: The full embedding dimension; must be divisible by 4.
        dtype: The dtype of the returned tables.
        offset: The position of the first row.
        base: The base for the sinusoidal frequencies.

    Returns:
        Array of shape ``(2, tsz, embed_dim // 2)``; index 0 holds the cosine
        table and index 1 the sine table.
    """
    assert embed_dim % 4 == 0, f"Embedding dimension must be divisible by 4, got {embed_dim}"
    rot_dim = embed_dim // 2
    exponents = jnp.arange(0, rot_dim, 2, dtype=jnp.float32) / rot_dim
    inv_freq = 1.0 / (base**exponents)
    positions = jnp.arange(offset, tsz + offset, dtype=jnp.float32)
    angles = jnp.einsum("t,c->tc", positions, inv_freq)
    angles = jnp.concatenate([angles, angles], axis=1)
    tables = jnp.stack((jnp.cos(angles), jnp.sin(angles)), axis=0)
    return tables.astype(dtype)
|
154
|
+
|
155
|
+
|
156
|
+
def apply_rotary_embeddings(x_tc: Array, embs_2tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
|
157
|
+
cos_tc, sin_tc = embs_2tc[0], embs_2tc[1]
|
158
|
+
tsz, embed_dim = x_tc.shape
|
159
|
+
half_d = embed_dim // 2
|
160
|
+
quarter_d = embed_dim // 4
|
161
|
+
x_rope_tc, x_pass_tc = x_tc[..., :half_d], x_tc[..., half_d:]
|
162
|
+
neg_half_x_tc = jnp.concatenate([-x_rope_tc[..., quarter_d:], x_rope_tc[..., :quarter_d]], axis=-1)
|
163
|
+
cos_part_tc = cos_tc[offset : offset + tsz] if times_t is None else cos_tc[times_t]
|
164
|
+
sin_part_tc = sin_tc[offset : offset + tsz] if times_t is None else sin_tc[times_t]
|
165
|
+
x_rope_tc = x_rope_tc * cos_part_tc + neg_half_x_tc * sin_part_tc
|
166
|
+
return jnp.concatenate((x_rope_tc, x_pass_tc), axis=-1)
|
167
|
+
|
168
|
+
|
169
|
+
def rotary_embeddings(x_tc: Array, offset: int = 0, base: int = 10_000) -> Array:
    """Applies rotary embeddings in a single function call.

    This is the functional counterpart of ``RotaryEmbeddings`` (without
    ``times_t`` support); it needs no pre-initialized state, so it can be
    used when running online.

    Args:
        x_tc: The input tensor, with shape ``(tsz, embed_dim)``; ``embed_dim``
            must be divisible by 4. To apply over a batch, map over the
            leading axis (e.g. with ``jax.vmap``).
        offset: The position of the first element.
        base: The base for the sinusoidal embeddings.

    Returns:
        The input tensor with rotary embeddings applied.
    """
    tsz, embed_dim = x_tc.shape
    # Generate only the rows actually used (positions offset .. offset+tsz-1)
    # rather than materializing offset + tsz rows and slicing off the prefix;
    # the resulting values are identical.
    emb_2tc = get_rotary_embeddings(tsz, embed_dim, x_tc.dtype, offset, base)
    return apply_rotary_embeddings(x_tc, emb_2tc)
|
186
|
+
|
187
|
+
|
188
|
+
class RotaryEmbeddings(eqx.Module):
    """Applies rotary positional embeddings, building the tables on each call.

    Parameters:
        base: The base for the sinusoidal embeddings.
    """

    base: int = eqx.field(static=True)

    def __init__(self, base: int = 10_000) -> None:
        """Stores the frequency base.

        Args:
            base: The base for the sinusoidal embeddings.
        """
        super().__init__()
        self.base = base

    def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
        """Rotates ``x_tc`` according to its positions.

        Args:
            x_tc: Input of shape ``(tsz, embed_dim)``.
            offset: The position of the first row; used when ``times_t`` is None.
            times_t: Optional explicit positions for each row. Its maximum is
                read with ``.item()``, so it must be concrete (not traced).

        Returns:
            ``x_tc`` with rotary embeddings applied.
        """
        tsz, embed_dim = x_tc.shape
        # The tables need a row for every position that may be looked up.
        rows_needed = tsz if times_t is None else max(tsz, int(times_t.max().item()) + 1)
        tables_2tc = get_rotary_embeddings(rows_needed + offset, embed_dim, x_tc.dtype, 0, self.base)
        return apply_rotary_embeddings(x_tc, tables_2tc, offset, times_t)
|
212
|
+
|
213
|
+
|
214
|
+
@overload
def get_positional_embeddings(kind: Literal["identity"]) -> IdentityPositionalEmbeddings: ...


@overload
def get_positional_embeddings(
    kind: Literal["learned"],
    *,
    max_tsz: int,
    embed_dim: int,
    learnable: bool | None = None,
    key: PRNGKeyArray,
) -> LearnedPositionalEmbeddings: ...


@overload
def get_positional_embeddings(
    kind: Literal["sinusoidal"],
    *,
    max_tsz: int | None = None,
    embed_dim: int | None = None,
    learnable: bool | None = None,
    base: int = 10_000,
) -> SinusoidalEmbeddings: ...


@overload
def get_positional_embeddings(
    kind: Literal["rotary"],
    *,
    base: int = 10_000,
) -> RotaryEmbeddings: ...


@overload
def get_positional_embeddings(
    kind: EmbeddingKind,
    *,
    max_tsz: int | None = None,
    embed_dim: int | None = None,
    learnable: bool | None = None,
    base: int = 10_000,
    key: PRNGKeyArray | None = None,
) -> IdentityPositionalEmbeddings | LearnedPositionalEmbeddings | SinusoidalEmbeddings | RotaryEmbeddings: ...


def get_positional_embeddings(
    kind: EmbeddingKind,
    *,
    max_tsz: int | None = None,
    embed_dim: int | None = None,
    learnable: bool | None = None,
    base: int = 10_000,
    key: PRNGKeyArray | None = None,
) -> eqx.Module:
    """Builds the positional embeddings module selected by ``kind``.

    Args:
        kind: Which module to build: ``"identity"``, ``"learned"``,
            ``"sinusoidal"`` or ``"rotary"``.
        max_tsz: The maximum sequence length (required for ``"learned"``).
        embed_dim: The embedding dimension (required for ``"learned"``).
        learnable: Whether the embeddings are learnable. When omitted, learned
            embeddings default to learnable and sinusoidal ones to fixed.
        base: The base for the sinusoidal and rotary embeddings.
        key: The PRNG key for initializing learned embeddings (required for
            ``"learned"``).

    Returns:
        The positional embeddings module.

    Raises:
        ValueError: If an invalid embedding kind is supplied.
    """
    if kind == "identity":
        return IdentityPositionalEmbeddings()

    if kind == "learned":
        assert max_tsz is not None, "Learned embeddings require `max_tsz` to be set"
        assert embed_dim is not None, "Learned embeddings require `embed_dim` to be set"
        assert key is not None, "Learned embeddings require `key` to be set"
        return LearnedPositionalEmbeddings(
            max_tsz=max_tsz,
            embed_dim=embed_dim,
            learnable=learnable if learnable is not None else True,
            key=key,
        )

    if kind == "sinusoidal":
        return SinusoidalEmbeddings(
            max_tsz=max_tsz,
            embed_dim=embed_dim,
            learnable=learnable if learnable is not None else False,
            base=base,
        )

    if kind == "rotary":
        return RotaryEmbeddings(base=base)

    raise ValueError(f"Invalid embedding kind: {kind}")
|
315
|
+
|
316
|
+
|
317
|
+
def fourier_embeddings(t: Array, dim: int, max_period: int = 10000) -> Array:
    """Computes Fourier (sinusoidal) embeddings of continuous time values.

    Args:
        t: The time values. Any shape is accepted, including a scalar; the
            embedding is appended as a new trailing axis.
        dim: The embedding dimension. ``dim // 2`` frequencies are used; when
            ``dim`` is odd, a trailing column of zeros pads the output.
        max_period: Sets the lowest frequency: the frequencies are
            ``max_period ** (-i / (dim // 2))`` for ``i = 0 .. dim // 2 - 1``,
            i.e. geometrically spaced from 1 down towards ``1 / max_period``.

    Returns:
        Array of shape ``t.shape + (dim,)``: cosines in the first ``dim // 2``
        columns, sines in the next ``dim // 2``, then the optional zero column.
    """
    t = jnp.asarray(t)
    half = dim // 2
    idxs = jnp.arange(start=0, stop=half, dtype=jnp.float32)
    freqs = jnp.exp(-math.log(max_period) * idxs / half)
    # `...` instead of `:` so scalar and multi-dimensional `t` also work.
    args = t[..., None].astype(jnp.float32) * freqs
    embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
    # Adds a trailing column of zeros so an odd `dim` is matched exactly.
    if dim % 2:
        embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[..., :1])], axis=-1)
    return embedding
|
327
|
+
|
328
|
+
|
329
|
+
class FourierEmbeddings(eqx.Module):
    """Defines a module for applying Fourier embeddings to timesteps.

    This module differs from the other positional embedding modules because it
    expects a continuous time input, rather than a discrete time input.

    Parameters:
        dim: The number of embedding dimensions. ``dim // 2`` frequencies are
            used, geometrically spaced from 1 down towards ``1 / max_period``;
            a larger ``dim`` samples that range more densely (it does not
            raise the highest frequency). An odd ``dim`` gets a zero column.
        max_period: Sets the lowest frequency used (approaching
            ``1 / max_period``). This should roughly be in line with the
            maximum number of timesteps; 10,000 is a commonly used default.
    """

    # Static, like the integer configuration fields of the other embedding
    # modules: these values are consumed as Python numbers (an array length
    # and a `math.log` argument), so they must not become pytree leaves that
    # JAX transformations could trace or that `tree_map` could modify.
    dim: int = eqx.field(static=True)
    max_period: int = eqx.field(static=True)

    def __init__(self, dim: int, max_period: int = 10000) -> None:
        super().__init__()

        self.dim = dim
        self.max_period = max_period

    def __call__(self, t: Array) -> Array:
        """Embeds the time values ``t`` using this module's ``dim`` and ``max_period``."""
        return fourier_embeddings(t, self.dim, self.max_period)
|
xax/nn/functions.py
CHANGED
@@ -5,6 +5,7 @@ import random
|
|
5
5
|
from dataclasses import is_dataclass
|
6
6
|
from typing import Any, Callable, Iterable, Mapping, ParamSpec, Sequence, TypeVar
|
7
7
|
|
8
|
+
import jax.numpy as jnp
|
8
9
|
import numpy as np
|
9
10
|
from jaxtyping import Array
|
10
11
|
|
@@ -14,26 +15,29 @@ T = TypeVar("T")
|
|
14
15
|
P = ParamSpec("P")
|
15
16
|
|
16
17
|
|
17
|
-
def recursive_apply(item: Any, func: Callable[[Array], Array], include_numpy: bool = False) -> Any:  # noqa: ANN401
    """Applies a function recursively to tensors in an item.

    Args:
        item: The item to apply the function to. Dataclass instances, mappings
            and sequences are traversed; mappings are returned as ``dict`` and
            sequences (other than strings and byte strings) as ``list``.
        func: The function to apply (for the tensor)
        include_numpy: If set, numpy arrays are converted to JAX arrays and
            passed to ``func`` too; otherwise they are returned unchanged.

    Returns:
        The same item, with the function applied
    """
    # Scalars, strings and byte strings are leaves. `bytes`/`bytearray` are
    # Sequences, so without this check they would be exploded into int lists.
    if isinstance(item, (str, bytes, bytearray, int, float)):
        return item
    if include_numpy and isinstance(item, np.ndarray):
        return func(jnp.array(item))
    if isinstance(item, Array):
        return func(item)
    # `is_dataclass` is also true for dataclass *types*; only rebuild instances.
    if is_dataclass(item) and not isinstance(item, type):
        return item.__class__(**{k: recursive_apply(v, func, include_numpy) for k, v in item.__dict__.items()})
    if isinstance(item, Mapping):
        return {k: recursive_apply(v, func, include_numpy) for k, v in item.items()}
    if isinstance(item, Sequence):
        return [recursive_apply(i, func, include_numpy) for i in item]
    return item
|
38
42
|
|
39
43
|
|
xax/requirements-dev.txt
CHANGED
xax/requirements.txt
CHANGED
@@ -1,18 +1,25 @@
|
|
1
1
|
# requirements.txt
|
2
2
|
|
3
|
-
|
4
|
-
equinox
|
5
|
-
gitpython
|
3
|
+
# Core ML/JAX dependencies
|
6
4
|
jax
|
7
5
|
jaxtyping
|
8
|
-
|
6
|
+
equinox
|
9
7
|
optax
|
8
|
+
dpshdl
|
9
|
+
chex
|
10
|
+
importlib-resources
|
11
|
+
|
12
|
+
# Data processing and serialization
|
13
|
+
cloudpickle
|
10
14
|
pillow
|
11
|
-
|
12
|
-
|
15
|
+
|
16
|
+
# Configuration and project management
|
17
|
+
omegaconf
|
18
|
+
gitpython
|
19
|
+
|
20
|
+
# Monitoring and logging
|
13
21
|
tensorboard
|
22
|
+
psutil
|
14
23
|
|
15
|
-
#
|
16
|
-
|
17
|
-
types-psutil
|
18
|
-
types-requests
|
24
|
+
# Networking
|
25
|
+
requests
|
xax/task/base.py
CHANGED
@@ -15,6 +15,7 @@ from pathlib import Path
|
|
15
15
|
from types import TracebackType
|
16
16
|
from typing import Generic, Self, TypeVar, cast
|
17
17
|
|
18
|
+
import jax
|
18
19
|
from omegaconf import Container, DictConfig, OmegaConf
|
19
20
|
|
20
21
|
from xax.core.state import State
|
@@ -23,6 +24,7 @@ from xax.utils.text import camelcase_to_snakecase
|
|
23
24
|
logger = logging.getLogger(__name__)
|
24
25
|
|
25
26
|
|
27
|
+
@jax.tree_util.register_dataclass
|
26
28
|
@dataclass
|
27
29
|
class BaseConfig:
|
28
30
|
pass
|
@@ -79,12 +81,6 @@ class BaseTask(Generic[Config]):
|
|
79
81
|
def on_training_end(self, state: State) -> State:
|
80
82
|
return state
|
81
83
|
|
82
|
-
def on_before_save_checkpoint(self, ckpt_path: Path) -> None:
|
83
|
-
pass
|
84
|
-
|
85
|
-
def on_after_save_checkpoint(self, ckpt_path: Path) -> None:
|
86
|
-
pass
|
87
|
-
|
88
84
|
@functools.cached_property
|
89
85
|
def task_class_name(self) -> str:
|
90
86
|
return self.__class__.__name__
|