xax 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +49 -7
- xax/core/conf.py +1 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +8 -4
- xax/requirements-dev.txt +9 -1
- xax/requirements.txt +15 -10
- xax/task/base.py +0 -6
- xax/task/logger.py +328 -393
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/tensorboard.py +2 -5
- xax/task/mixins/__init__.py +2 -1
- xax/task/mixins/artifacts.py +14 -7
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +10 -10
- xax/task/mixins/data_loader.py +6 -9
- xax/task/mixins/gpu_stats.py +3 -3
- xax/task/mixins/logger.py +2 -250
- xax/task/mixins/process.py +4 -0
- xax/task/mixins/train.py +71 -40
- xax/task/task.py +6 -5
- xax/utils/data/collate.py +6 -6
- xax/utils/experiments.py +45 -1
- xax/utils/logging.py +29 -0
- xax/utils/tensorboard.py +49 -29
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/METADATA +15 -14
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.3.dist-info/RECORD +0 -49
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
- {xax-0.0.3.dist-info → xax-0.0.5.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -11,7 +11,7 @@ This file can be maintained by running the update script:
|
|
11
11
|
python -m scripts.update_api --inplace
|
12
12
|
"""
|
13
13
|
|
14
|
-
__version__ = "0.0.3"
|
14
|
+
__version__ = "0.0.5"
|
15
15
|
|
16
16
|
# This list shouldn't be modified by hand; instead, run the update script.
|
17
17
|
__all__ = [
|
@@ -23,15 +23,25 @@ __all__ = [
|
|
23
23
|
"load_user_config",
|
24
24
|
"State",
|
25
25
|
"cast_phase",
|
26
|
+
"FourierEmbeddings",
|
27
|
+
"IdentityPositionalEmbeddings",
|
28
|
+
"LearnedPositionalEmbeddings",
|
29
|
+
"RotaryEmbeddings",
|
30
|
+
"SinusoidalEmbeddings",
|
31
|
+
"apply_rotary_embeddings",
|
32
|
+
"cast_embedding_kind",
|
33
|
+
"fourier_embeddings",
|
34
|
+
"get_positional_embeddings",
|
35
|
+
"get_rotary_embeddings",
|
36
|
+
"rotary_embeddings",
|
26
37
|
"BaseLauncher",
|
27
38
|
"CliLauncher",
|
28
39
|
"SingleProcessLauncher",
|
29
|
-
"LogAudio",
|
30
40
|
"LogImage",
|
31
41
|
"LogLine",
|
32
|
-
"LogVideo",
|
33
42
|
"Logger",
|
34
43
|
"LoggerImpl",
|
44
|
+
"CallbackLogger",
|
35
45
|
"JsonLogger",
|
36
46
|
"StateLogger",
|
37
47
|
"StdoutLogger",
|
@@ -72,7 +82,10 @@ __all__ = [
|
|
72
82
|
]
|
73
83
|
|
74
84
|
__all__ += [
|
85
|
+
"Batch",
|
75
86
|
"CollateMode",
|
87
|
+
"EmbeddingKind",
|
88
|
+
"Output",
|
76
89
|
"Phase",
|
77
90
|
]
|
78
91
|
|
@@ -95,21 +108,31 @@ NAME_MAP: dict[str, str] = {
|
|
95
108
|
"load_user_config": "core.conf",
|
96
109
|
"State": "core.state",
|
97
110
|
"cast_phase": "core.state",
|
111
|
+
"FourierEmbeddings": "nn.embeddings",
|
112
|
+
"IdentityPositionalEmbeddings": "nn.embeddings",
|
113
|
+
"LearnedPositionalEmbeddings": "nn.embeddings",
|
114
|
+
"RotaryEmbeddings": "nn.embeddings",
|
115
|
+
"SinusoidalEmbeddings": "nn.embeddings",
|
116
|
+
"apply_rotary_embeddings": "nn.embeddings",
|
117
|
+
"cast_embedding_kind": "nn.embeddings",
|
118
|
+
"fourier_embeddings": "nn.embeddings",
|
119
|
+
"get_positional_embeddings": "nn.embeddings",
|
120
|
+
"get_rotary_embeddings": "nn.embeddings",
|
121
|
+
"rotary_embeddings": "nn.embeddings",
|
98
122
|
"BaseLauncher": "task.launchers.base",
|
99
123
|
"CliLauncher": "task.launchers.cli",
|
100
124
|
"SingleProcessLauncher": "task.launchers.single_process",
|
101
|
-
"LogAudio": "task.logger",
|
102
125
|
"LogImage": "task.logger",
|
103
126
|
"LogLine": "task.logger",
|
104
|
-
"LogVideo": "task.logger",
|
105
127
|
"Logger": "task.logger",
|
106
128
|
"LoggerImpl": "task.logger",
|
129
|
+
"CallbackLogger": "task.loggers.callback",
|
107
130
|
"JsonLogger": "task.loggers.json",
|
108
131
|
"StateLogger": "task.loggers.state",
|
109
132
|
"StdoutLogger": "task.loggers.stdout",
|
110
133
|
"TensorboardLogger": "task.loggers.tensorboard",
|
111
134
|
"CPUStatsOptions": "task.mixins.cpu_stats",
|
112
|
-
"
|
135
|
+
"DataloaderConfig": "task.mixins.data_loader",
|
113
136
|
"GPUStatsOptions": "task.mixins.gpu_stats",
|
114
137
|
"Script": "task.script",
|
115
138
|
"ScriptConfig": "task.script",
|
@@ -146,7 +169,10 @@ NAME_MAP: dict[str, str] = {
|
|
146
169
|
# Need to manually set some values which can't be auto-generated.
|
147
170
|
NAME_MAP.update(
|
148
171
|
{
|
172
|
+
"Batch": "task.mixins.train",
|
149
173
|
"CollateMode": "utils.data.collate",
|
174
|
+
"EmbeddingKind": "nn.embeddings",
|
175
|
+
"Output": "task.mixins.output",
|
150
176
|
"Phase": "core.state",
|
151
177
|
},
|
152
178
|
)
|
@@ -171,10 +197,25 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
171
197
|
load_user_config,
|
172
198
|
)
|
173
199
|
from xax.core.state import Phase, State, cast_phase
|
200
|
+
from xax.nn.embeddings import (
|
201
|
+
EmbeddingKind,
|
202
|
+
FourierEmbeddings,
|
203
|
+
IdentityPositionalEmbeddings,
|
204
|
+
LearnedPositionalEmbeddings,
|
205
|
+
RotaryEmbeddings,
|
206
|
+
SinusoidalEmbeddings,
|
207
|
+
apply_rotary_embeddings,
|
208
|
+
cast_embedding_kind,
|
209
|
+
fourier_embeddings,
|
210
|
+
get_positional_embeddings,
|
211
|
+
get_rotary_embeddings,
|
212
|
+
rotary_embeddings,
|
213
|
+
)
|
174
214
|
from xax.task.launchers.base import BaseLauncher
|
175
215
|
from xax.task.launchers.cli import CliLauncher
|
176
216
|
from xax.task.launchers.single_process import SingleProcessLauncher
|
177
|
-
from xax.task.logger import LogAudio, Logger, LoggerImpl, LogImage, LogLine, LogVideo
|
217
|
+
from xax.task.logger import Logger, LoggerImpl, LogImage, LogLine
|
218
|
+
from xax.task.loggers.callback import CallbackLogger
|
178
219
|
from xax.task.loggers.json import JsonLogger
|
179
220
|
from xax.task.loggers.state import StateLogger
|
180
221
|
from xax.task.loggers.stdout import StdoutLogger
|
@@ -182,6 +223,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
182
223
|
from xax.task.mixins.cpu_stats import CPUStatsOptions
|
183
224
|
from xax.task.mixins.data_loader import DataloaderConfig
|
184
225
|
from xax.task.mixins.gpu_stats import GPUStatsOptions
|
226
|
+
from xax.task.mixins.train import Batch, Output
|
185
227
|
from xax.task.script import Script, ScriptConfig
|
186
228
|
from xax.task.task import Config, Task
|
187
229
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
xax/core/conf.py
CHANGED
xax/nn/embeddings.py
ADDED
@@ -0,0 +1,355 @@
|
|
1
|
+
"""Defines embedding layers."""
|
2
|
+
|
3
|
+
import math
|
4
|
+
from typing import Literal, cast, get_args, overload
|
5
|
+
|
6
|
+
import equinox as eqx
|
7
|
+
import jax
|
8
|
+
import jax.numpy as jnp
|
9
|
+
import jax.random as jrandom
|
10
|
+
from jaxtyping import Array, DTypeLike, PRNGKeyArray
|
11
|
+
|
12
|
+
EmbeddingKind = Literal["identity", "learned", "sinusoidal", "rotary"]
|
13
|
+
|
14
|
+
|
15
|
+
def cast_embedding_kind(k: str) -> EmbeddingKind:
|
16
|
+
args = get_args(EmbeddingKind)
|
17
|
+
assert k in args, f"Invalid initialization type: '{k}' Valid options are {args}"
|
18
|
+
return cast(EmbeddingKind, k)
|
19
|
+
|
20
|
+
|
21
|
+
class IdentityPositionalEmbeddings(eqx.Module):
|
22
|
+
def __call__(self, x: Array, offset: int = 0, times_t: Array | None = None) -> Array:
|
23
|
+
return x
|
24
|
+
|
25
|
+
|
26
|
+
class LearnedPositionalEmbeddings(eqx.Module):
|
27
|
+
"""Defines a learned embeddings module.
|
28
|
+
|
29
|
+
Parameters:
|
30
|
+
max_tsz: The maximum sequence length.
|
31
|
+
embed_dim: The embedding dimension.
|
32
|
+
weight_init: The initialization type for the embedding weight.
|
33
|
+
learnable: Whether the embeddings are learnable.
|
34
|
+
"""
|
35
|
+
|
36
|
+
max_tsz: int = eqx.field(static=True)
|
37
|
+
embed_dim: int = eqx.field(static=True)
|
38
|
+
learnable: bool = eqx.field(static=True)
|
39
|
+
embeddings_tc: Array
|
40
|
+
|
41
|
+
def __init__(
|
42
|
+
self,
|
43
|
+
max_tsz: int,
|
44
|
+
embed_dim: int,
|
45
|
+
learnable: bool = True,
|
46
|
+
*,
|
47
|
+
key: PRNGKeyArray,
|
48
|
+
) -> None:
|
49
|
+
super().__init__()
|
50
|
+
|
51
|
+
self.max_tsz = max_tsz
|
52
|
+
self.embed_dim = embed_dim
|
53
|
+
self.learnable = learnable
|
54
|
+
|
55
|
+
self.embeddings_tc = jrandom.normal(key, (max_tsz, embed_dim))
|
56
|
+
|
57
|
+
def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
|
58
|
+
if times_t is None:
|
59
|
+
emb_tc = self.embeddings_tc[offset : offset + x_tc.shape[-2]]
|
60
|
+
else:
|
61
|
+
emb_tc = self.embeddings_tc[times_t]
|
62
|
+
if not self.learnable:
|
63
|
+
emb_tc = jax.lax.stop_gradient(emb_tc)
|
64
|
+
return emb_tc
|
65
|
+
|
66
|
+
|
67
|
+
class SinusoidalEmbeddings(eqx.Module):
|
68
|
+
"""Defines a sinusoidal embeddings module.
|
69
|
+
|
70
|
+
Parameters:
|
71
|
+
embed_dim: The embedding dimension.
|
72
|
+
max_tsz: The maximum sequence length.
|
73
|
+
learnable: Whether the embeddings are learnable.
|
74
|
+
base: The base for the sinusoidal embeddings.
|
75
|
+
"""
|
76
|
+
|
77
|
+
base: int = eqx.field(static=True)
|
78
|
+
max_tsz: int | None = eqx.field(static=True)
|
79
|
+
embed_dim: int | None = eqx.field(static=True)
|
80
|
+
embeddings_tc: Array | None
|
81
|
+
|
82
|
+
def __init__(
|
83
|
+
self,
|
84
|
+
embed_dim: int | None = None,
|
85
|
+
max_tsz: int | None = None,
|
86
|
+
learnable: bool = True,
|
87
|
+
base: int = 10_000,
|
88
|
+
) -> None:
|
89
|
+
super().__init__()
|
90
|
+
|
91
|
+
self.max_tsz = max_tsz
|
92
|
+
self.embed_dim = embed_dim
|
93
|
+
self.base = base
|
94
|
+
|
95
|
+
self.embeddings_tc: Array | None = None
|
96
|
+
if learnable:
|
97
|
+
assert max_tsz is not None, "Learnable parameters require `max_tsz` to be set"
|
98
|
+
assert embed_dim is not None, "Learnable parameters require `embed_dim` to be set"
|
99
|
+
self.embeddings_tc = self.get_embeddings(max_tsz, embed_dim)
|
100
|
+
|
101
|
+
def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
|
102
|
+
tsz, dims = x_tc.shape
|
103
|
+
|
104
|
+
# If the embeddings are learnable, use the property.
|
105
|
+
if self.embeddings_tc is None:
|
106
|
+
if times_t is None:
|
107
|
+
embeddings_tc = self.get_embeddings(offset + tsz, dims, x_tc.dtype)
|
108
|
+
else:
|
109
|
+
embeddings_tc = self.get_embeddings(times_t.max().item() + 1, dims, x_tc.dtype)
|
110
|
+
else:
|
111
|
+
embeddings_tc = self.embeddings_tc
|
112
|
+
|
113
|
+
# Get only the embeddings for the specified time steps.
|
114
|
+
if times_t is None:
|
115
|
+
embeddings_tc = embeddings_tc[offset : offset + tsz]
|
116
|
+
else:
|
117
|
+
embeddings_tc = embeddings_tc[times_t]
|
118
|
+
|
119
|
+
return x_tc + embeddings_tc
|
120
|
+
|
121
|
+
def get_embeddings(
|
122
|
+
self,
|
123
|
+
tsz: int,
|
124
|
+
embed_dim: int,
|
125
|
+
dtype: DTypeLike | None = None,
|
126
|
+
) -> Array:
|
127
|
+
positions_t = jax.numpy.arange(tsz, dtype=dtype)
|
128
|
+
dim_d = jax.numpy.arange(embed_dim, dtype=dtype)
|
129
|
+
dim_d = self.base ** (2 * (dim_d // 2) / embed_dim)
|
130
|
+
embeddings_td = positions_t[:, None] / dim_d[None, :]
|
131
|
+
embeddings_td = jnp.concatenate(
|
132
|
+
[jax.numpy.sin(embeddings_td[:, 0::2]), jax.numpy.cos(embeddings_td[:, 1::2])],
|
133
|
+
axis=-1,
|
134
|
+
)
|
135
|
+
return embeddings_td.astype(dtype)
|
136
|
+
|
137
|
+
|
138
|
+
def get_rotary_embeddings(
|
139
|
+
tsz: int,
|
140
|
+
embed_dim: int,
|
141
|
+
dtype: jnp.dtype,
|
142
|
+
offset: int = 0,
|
143
|
+
base: int = 10_000,
|
144
|
+
) -> Array:
|
145
|
+
assert embed_dim % 4 == 0, f"Embedding dimension must be divisible by 4, got {embed_dim}"
|
146
|
+
half_d = embed_dim // 2
|
147
|
+
theta = 1.0 / (base ** (jnp.arange(0, half_d, 2, dtype=jnp.float32) / half_d))
|
148
|
+
seq_idx = jnp.arange(offset, tsz + offset, dtype=jnp.float32)
|
149
|
+
idx_theta_tc = jnp.einsum("t,c->tc", seq_idx, theta)
|
150
|
+
idx_theta2_tc = jnp.concatenate([idx_theta_tc, idx_theta_tc], axis=1)
|
151
|
+
cos_tc, sin_tc = jnp.cos(idx_theta2_tc), jnp.sin(idx_theta2_tc)
|
152
|
+
emb_2tc = jnp.stack((cos_tc, sin_tc), axis=0)
|
153
|
+
return emb_2tc.astype(dtype)
|
154
|
+
|
155
|
+
|
156
|
+
def apply_rotary_embeddings(x_tc: Array, embs_2tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
|
157
|
+
cos_tc, sin_tc = embs_2tc[0], embs_2tc[1]
|
158
|
+
tsz, embed_dim = x_tc.shape
|
159
|
+
half_d = embed_dim // 2
|
160
|
+
quarter_d = embed_dim // 4
|
161
|
+
x_rope_tc, x_pass_tc = x_tc[..., :half_d], x_tc[..., half_d:]
|
162
|
+
neg_half_x_tc = jnp.concatenate([-x_rope_tc[..., quarter_d:], x_rope_tc[..., :quarter_d]], axis=-1)
|
163
|
+
cos_part_tc = cos_tc[offset : offset + tsz] if times_t is None else cos_tc[times_t]
|
164
|
+
sin_part_tc = sin_tc[offset : offset + tsz] if times_t is None else sin_tc[times_t]
|
165
|
+
x_rope_tc = x_rope_tc * cos_part_tc + neg_half_x_tc * sin_part_tc
|
166
|
+
return jnp.concatenate((x_rope_tc, x_pass_tc), axis=-1)
|
167
|
+
|
168
|
+
|
169
|
+
def rotary_embeddings(x_tc: Array, offset: int = 0, base: int = 10_000) -> Array:
|
170
|
+
"""Defines a single function for applying rotary embeddings.
|
171
|
+
|
172
|
+
This is slower than using the module, but it doesn't require
|
173
|
+
pre-initializing the embeddings, so it can be used when running online.
|
174
|
+
|
175
|
+
Args:
|
176
|
+
x_tc: The input tensor, with shape ``(batch, tsz, embed_dim)``.
|
177
|
+
offset: The offset for the first element.
|
178
|
+
base: The base for the sinusoidal embeddings.
|
179
|
+
|
180
|
+
Returns:
|
181
|
+
The input tensor with rotary embeddings applied.
|
182
|
+
"""
|
183
|
+
(tsz, embed_dim), dtype = x_tc.shape, x_tc.dtype
|
184
|
+
emb_2tc = get_rotary_embeddings(tsz + offset, embed_dim, dtype, 0, base)
|
185
|
+
return apply_rotary_embeddings(x_tc, emb_2tc, offset)
|
186
|
+
|
187
|
+
|
188
|
+
class RotaryEmbeddings(eqx.Module):
|
189
|
+
"""Defines a rotary embeddings module.
|
190
|
+
|
191
|
+
Parameters:
|
192
|
+
base: The base for the sinusoidal embeddings.
|
193
|
+
"""
|
194
|
+
|
195
|
+
base: int = eqx.field(static=True)
|
196
|
+
|
197
|
+
def __init__(self, base: int = 10_000) -> None:
|
198
|
+
"""Defines a rotary embeddings module.
|
199
|
+
|
200
|
+
Args:
|
201
|
+
base: The base for the sinusoidal embeddings.
|
202
|
+
"""
|
203
|
+
super().__init__()
|
204
|
+
|
205
|
+
self.base = base
|
206
|
+
|
207
|
+
def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
|
208
|
+
tsz, embed_dim = x_tc.shape
|
209
|
+
max_tsz = max(tsz, 0 if times_t is None else int(times_t.max().item()) + 1) + offset
|
210
|
+
emb_2tc = get_rotary_embeddings(max_tsz, embed_dim, x_tc.dtype, 0, self.base)
|
211
|
+
return apply_rotary_embeddings(x_tc, emb_2tc, offset, times_t)
|
212
|
+
|
213
|
+
|
214
|
+
@overload
|
215
|
+
def get_positional_embeddings(kind: Literal["identity"]) -> IdentityPositionalEmbeddings: ...
|
216
|
+
|
217
|
+
|
218
|
+
@overload
|
219
|
+
def get_positional_embeddings(
|
220
|
+
kind: Literal["learned"],
|
221
|
+
*,
|
222
|
+
max_tsz: int,
|
223
|
+
embed_dim: int,
|
224
|
+
learnable: bool | None = None,
|
225
|
+
key: PRNGKeyArray,
|
226
|
+
) -> LearnedPositionalEmbeddings: ...
|
227
|
+
|
228
|
+
|
229
|
+
@overload
|
230
|
+
def get_positional_embeddings(
|
231
|
+
kind: Literal["sinusoidal"],
|
232
|
+
*,
|
233
|
+
max_tsz: int | None = None,
|
234
|
+
embed_dim: int | None = None,
|
235
|
+
learnable: bool | None = None,
|
236
|
+
base: int = 10_000,
|
237
|
+
) -> SinusoidalEmbeddings: ...
|
238
|
+
|
239
|
+
|
240
|
+
@overload
|
241
|
+
def get_positional_embeddings(
|
242
|
+
kind: Literal["rotary"],
|
243
|
+
*,
|
244
|
+
base: int = 10_000,
|
245
|
+
) -> RotaryEmbeddings: ...
|
246
|
+
|
247
|
+
|
248
|
+
@overload
|
249
|
+
def get_positional_embeddings(
|
250
|
+
kind: EmbeddingKind,
|
251
|
+
*,
|
252
|
+
max_tsz: int | None = None,
|
253
|
+
embed_dim: int | None = None,
|
254
|
+
learnable: bool | None = None,
|
255
|
+
base: int = 10_000,
|
256
|
+
key: PRNGKeyArray | None = None,
|
257
|
+
) -> IdentityPositionalEmbeddings | LearnedPositionalEmbeddings | SinusoidalEmbeddings | RotaryEmbeddings: ...
|
258
|
+
|
259
|
+
|
260
|
+
def get_positional_embeddings(
|
261
|
+
kind: EmbeddingKind,
|
262
|
+
*,
|
263
|
+
max_tsz: int | None = None,
|
264
|
+
embed_dim: int | None = None,
|
265
|
+
learnable: bool | None = None,
|
266
|
+
base: int = 10_000,
|
267
|
+
key: PRNGKeyArray | None = None,
|
268
|
+
) -> eqx.Module:
|
269
|
+
"""Defines the common module for adding positional embeddings.
|
270
|
+
|
271
|
+
Args:
|
272
|
+
kind: The type of embedding to use.
|
273
|
+
max_tsz: The maximum sequence length.
|
274
|
+
embed_dim: The embedding dimension.
|
275
|
+
learnable: Whether the embeddings are learnable; if not provided,
|
276
|
+
uses sensible defaults.
|
277
|
+
base: The base for the sinusoidal embeddings.
|
278
|
+
key: The PRNG key for initializing learnable embeddings.
|
279
|
+
|
280
|
+
Returns:
|
281
|
+
The positional embeddings module.
|
282
|
+
|
283
|
+
Raises:
|
284
|
+
ValueError: If an invalid embedding kind is supplied.
|
285
|
+
"""
|
286
|
+
match kind:
|
287
|
+
case "identity":
|
288
|
+
return IdentityPositionalEmbeddings()
|
289
|
+
|
290
|
+
case "learned":
|
291
|
+
assert max_tsz is not None, "Learned embeddings require `max_tsz` to be set"
|
292
|
+
assert embed_dim is not None, "Learned embeddings require `embed_dim` to be set"
|
293
|
+
assert key is not None, "Learned embeddings require `key` to be set"
|
294
|
+
|
295
|
+
return LearnedPositionalEmbeddings(
|
296
|
+
max_tsz=max_tsz,
|
297
|
+
embed_dim=embed_dim,
|
298
|
+
learnable=True if learnable is None else learnable,
|
299
|
+
key=key,
|
300
|
+
)
|
301
|
+
|
302
|
+
case "sinusoidal":
|
303
|
+
return SinusoidalEmbeddings(
|
304
|
+
max_tsz=max_tsz,
|
305
|
+
embed_dim=embed_dim,
|
306
|
+
learnable=False if learnable is None else learnable,
|
307
|
+
base=base,
|
308
|
+
)
|
309
|
+
|
310
|
+
case "rotary":
|
311
|
+
return RotaryEmbeddings(base=base)
|
312
|
+
|
313
|
+
case _:
|
314
|
+
raise ValueError(f"Invalid embedding kind: {kind}")
|
315
|
+
|
316
|
+
|
317
|
+
def fourier_embeddings(t: Array, dim: int, max_period: int = 10000) -> Array:
|
318
|
+
half = dim // 2
|
319
|
+
idxs = jnp.arange(start=0, stop=half, dtype=jnp.float32)
|
320
|
+
freqs = jnp.exp(-math.log(max_period) * idxs / half)
|
321
|
+
args = t[:, None].astype(jnp.float32) * freqs[None]
|
322
|
+
embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
|
323
|
+
# Adds an additional row of zeros to match the expected dimension.
|
324
|
+
if dim % 2:
|
325
|
+
embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1)
|
326
|
+
return embedding
|
327
|
+
|
328
|
+
|
329
|
+
class FourierEmbeddings(eqx.Module):
|
330
|
+
"""Defines a module for applying Fourier embeddings to timesteps.
|
331
|
+
|
332
|
+
This module differs from the other positional embedding modules because it
|
333
|
+
expects a continuous time input, rather than a discrete time input.
|
334
|
+
|
335
|
+
Parameters:
|
336
|
+
dim: The number of embedding dimensions. This value is used to determine
|
337
|
+
how many different frequencies to use, and a higher value means
|
338
|
+
higher frequencies.
|
339
|
+
max_period: The maximum period for the embeddings. This should roughly
|
340
|
+
be in line with the maximum number of timesteps; the default value
|
341
|
+
of 10,000 is commonly used in NLP applications, and is derived from
|
342
|
+
operating on sequence lengths of 100 to 1000 tokens.
|
343
|
+
"""
|
344
|
+
|
345
|
+
dim: int
|
346
|
+
max_period: int
|
347
|
+
|
348
|
+
def __init__(self, dim: int, max_period: int = 10000) -> None:
|
349
|
+
super().__init__()
|
350
|
+
|
351
|
+
self.dim = dim
|
352
|
+
self.max_period = max_period
|
353
|
+
|
354
|
+
def __call__(self, t: Array) -> Array:
|
355
|
+
return fourier_embeddings(t, self.dim, self.max_period)
|
xax/nn/functions.py
CHANGED
@@ -5,6 +5,7 @@ import random
|
|
5
5
|
from dataclasses import is_dataclass
|
6
6
|
from typing import Any, Callable, Iterable, Mapping, ParamSpec, Sequence, TypeVar
|
7
7
|
|
8
|
+
import jax.numpy as jnp
|
8
9
|
import numpy as np
|
9
10
|
from jaxtyping import Array
|
10
11
|
|
@@ -14,26 +15,29 @@ T = TypeVar("T")
|
|
14
15
|
P = ParamSpec("P")
|
15
16
|
|
16
17
|
|
17
|
-
def recursive_apply(item: Any, func: Callable[[Array], Array]) -> Any: # noqa: ANN401
|
18
|
+
def recursive_apply(item: Any, func: Callable[[Array], Array], include_numpy: bool = False) -> Any: # noqa: ANN401
|
18
19
|
"""Applies a function recursively to tensors in an item.
|
19
20
|
|
20
21
|
Args:
|
21
22
|
item: The item to apply the function to
|
22
23
|
func: The function to apply (for the tensor)
|
24
|
+
include_numpy: If set, include numpy arrays
|
23
25
|
|
24
26
|
Returns:
|
25
27
|
The same item, with the function applied
|
26
28
|
"""
|
27
29
|
if isinstance(item, (str, int, float)):
|
28
30
|
return item
|
31
|
+
if include_numpy and isinstance(item, np.ndarray):
|
32
|
+
return func(jnp.array(item))
|
29
33
|
if isinstance(item, Array):
|
30
34
|
return func(item)
|
31
35
|
if is_dataclass(item):
|
32
|
-
return item.__class__(**{k: recursive_apply(v, func) for k, v in item.__dict__.items()})
|
36
|
+
return item.__class__(**{k: recursive_apply(v, func, include_numpy) for k, v in item.__dict__.items()})
|
33
37
|
if isinstance(item, Mapping):
|
34
|
-
return {k: recursive_apply(v, func) for k, v in item.items()}
|
38
|
+
return {k: recursive_apply(v, func, include_numpy) for k, v in item.items()}
|
35
39
|
if isinstance(item, Sequence):
|
36
|
-
return [recursive_apply(i, func) for i in item]
|
40
|
+
return [recursive_apply(i, func, include_numpy) for i in item]
|
37
41
|
return item
|
38
42
|
|
39
43
|
|
xax/requirements-dev.txt
CHANGED
xax/requirements.txt
CHANGED
@@ -1,18 +1,23 @@
|
|
1
1
|
# requirements.txt
|
2
2
|
|
3
|
-
|
4
|
-
equinox
|
5
|
-
gitpython
|
3
|
+
# Core ML/JAX dependencies
|
6
4
|
jax
|
7
5
|
jaxtyping
|
8
|
-
|
6
|
+
equinox
|
9
7
|
optax
|
8
|
+
dpshdl
|
9
|
+
|
10
|
+
# Data processing and serialization
|
11
|
+
cloudpickle
|
10
12
|
pillow
|
11
|
-
|
12
|
-
|
13
|
+
|
14
|
+
# Configuration and project management
|
15
|
+
omegaconf
|
16
|
+
gitpython
|
17
|
+
|
18
|
+
# Monitoring and logging
|
13
19
|
tensorboard
|
20
|
+
psutil
|
14
21
|
|
15
|
-
#
|
16
|
-
|
17
|
-
types-psutil
|
18
|
-
types-requests
|
22
|
+
# Networking
|
23
|
+
requests
|
xax/task/base.py
CHANGED
@@ -79,12 +79,6 @@ class BaseTask(Generic[Config]):
|
|
79
79
|
def on_training_end(self, state: State) -> State:
|
80
80
|
return state
|
81
81
|
|
82
|
-
def on_before_save_checkpoint(self, ckpt_path: Path) -> None:
|
83
|
-
pass
|
84
|
-
|
85
|
-
def on_after_save_checkpoint(self, ckpt_path: Path) -> None:
|
86
|
-
pass
|
87
|
-
|
88
82
|
@functools.cached_property
|
89
83
|
def task_class_name(self) -> str:
|
90
84
|
return self.__class__.__name__
|