xax 0.0.3__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/nn/embeddings.py ADDED
@@ -0,0 +1,355 @@
1
+ """Defines embedding layers."""
2
+
3
+ import math
4
+ from typing import Literal, cast, get_args, overload
5
+
6
+ import equinox as eqx
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import jax.random as jrandom
10
+ from jaxtyping import Array, DTypeLike, PRNGKeyArray
11
+
12
EmbeddingKind = Literal["identity", "learned", "sinusoidal", "rotary"]


def cast_embedding_kind(k: str) -> EmbeddingKind:
    """Validates a string and narrows its type to ``EmbeddingKind``.

    Args:
        k: The candidate embedding kind.

    Returns:
        ``k`` unchanged, typed as an ``EmbeddingKind``.

    Raises:
        AssertionError: If ``k`` is not one of the supported embedding kinds.
    """
    args = get_args(EmbeddingKind)
    # An explicit raise (rather than a bare ``assert``) keeps the check active under
    # ``python -O``; the exception type is unchanged so existing callers still work.
    if k not in args:
        raise AssertionError(f"Invalid embedding kind: '{k}'. Valid options are {args}")
    return cast(EmbeddingKind, k)
19
+
20
+
21
class IdentityPositionalEmbeddings(eqx.Module):
    """Positional "embedding" that returns its input untouched.

    ``offset`` and ``times_t`` are accepted only so the call signature matches the
    other positional embedding modules; they are ignored.
    """

    def __call__(self, x: Array, offset: int = 0, times_t: Array | None = None) -> Array:
        del offset, times_t  # Unused; present for interface compatibility.
        return x
24
+
25
+
26
class LearnedPositionalEmbeddings(eqx.Module):
    """Defines a learned embeddings module.

    A ``(max_tsz, embed_dim)`` table is initialized from a standard normal
    distribution. On each call, the rows for the requested positions are added
    to the input, the same way ``SinusoidalEmbeddings`` combines its table with
    ``x_tc``.

    Parameters:
        max_tsz: The maximum sequence length (number of rows in the table).
        embed_dim: The embedding dimension.
        learnable: Whether the embeddings are learnable. If ``False``, no
            gradient flows into the table.
        key: The PRNG key used to initialize the table.
    """

    max_tsz: int = eqx.field(static=True)
    embed_dim: int = eqx.field(static=True)
    learnable: bool = eqx.field(static=True)
    embeddings_tc: Array

    def __init__(
        self,
        max_tsz: int,
        embed_dim: int,
        learnable: bool = True,
        *,
        key: PRNGKeyArray,
    ) -> None:
        super().__init__()

        self.max_tsz = max_tsz
        self.embed_dim = embed_dim
        self.learnable = learnable

        self.embeddings_tc = jrandom.normal(key, (max_tsz, embed_dim))

    def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
        """Adds the positional embeddings to ``x_tc``.

        Args:
            x_tc: The input, with the time axis second-to-last and the embedding
                axis last.
            offset: Position of the first time step of ``x_tc``; ignored when
                ``times_t`` is given.
            times_t: Optional explicit integer positions used to index the table.

        Returns:
            ``x_tc`` plus the selected embedding rows.
        """
        if times_t is None:
            emb_tc = self.embeddings_tc[offset : offset + x_tc.shape[-2]]
        else:
            emb_tc = self.embeddings_tc[times_t]
        if not self.learnable:
            # Treat the table as a constant: block gradients into it.
            emb_tc = jax.lax.stop_gradient(emb_tc)
        # Add (rather than return) the embeddings so this module matches the identity,
        # sinusoidal and rotary modules, which all return a transformed ``x``.
        return x_tc + emb_tc
65
+
66
+
67
class SinusoidalEmbeddings(eqx.Module):
    """Defines a sinusoidal embeddings module.

    The sinusoidal table is added to the input.

    - If ``learnable`` is set, the table is computed once in ``__init__`` and
      stored in the ``embeddings_tc`` array field.
    - Otherwise ``embeddings_tc`` is ``None`` and the table is recomputed on
      every call, sized to the positions that call needs.

    Parameters:
        embed_dim: The embedding dimension (required when ``learnable``).
        max_tsz: The maximum sequence length (required when ``learnable``).
        learnable: Whether the embeddings are learnable.
        base: The base for the sinusoidal embeddings.
    """

    base: int = eqx.field(static=True)
    max_tsz: int | None = eqx.field(static=True)
    embed_dim: int | None = eqx.field(static=True)
    embeddings_tc: Array | None

    def __init__(
        self,
        embed_dim: int | None = None,
        max_tsz: int | None = None,
        learnable: bool = True,
        base: int = 10_000,
    ) -> None:
        super().__init__()

        self.max_tsz = max_tsz
        self.embed_dim = embed_dim
        self.base = base

        self.embeddings_tc: Array | None = None
        if learnable:
            assert max_tsz is not None, "Learnable parameters require `max_tsz` to be set"
            assert embed_dim is not None, "Learnable parameters require `embed_dim` to be set"
            self.embeddings_tc = self.get_embeddings(max_tsz, embed_dim)

    def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
        """Adds sinusoidal embeddings to ``x_tc``.

        Args:
            x_tc: The input, of shape ``(tsz, dims)``.
            offset: Position of the first row of ``x_tc``; ignored when
                ``times_t`` is given.
            times_t: Optional explicit integer positions used to index the table.

        Returns:
            ``x_tc`` plus the embeddings, in ``x_tc``'s dtype.
        """
        tsz, dims = x_tc.shape

        if self.embeddings_tc is None:
            # Non-learnable: build a table just large enough for the requested positions.
            if times_t is None:
                num_positions = offset + tsz
            else:
                # NOTE: `.item()` needs a concrete value, so this path cannot be traced under `jax.jit`.
                num_positions = int(times_t.max().item()) + 1
            embeddings_tc = self.get_embeddings(num_positions, dims, x_tc.dtype)
        else:
            # Learnable: use the table stored on the module.
            embeddings_tc = self.embeddings_tc

        # Get only the embeddings for the specified time steps.
        if times_t is None:
            embeddings_tc = embeddings_tc[offset : offset + tsz]
        else:
            embeddings_tc = embeddings_tc[times_t]

        # Cast so both paths return the input dtype (the stored table is float32).
        return x_tc + embeddings_tc.astype(x_tc.dtype)

    def get_embeddings(
        self,
        tsz: int,
        embed_dim: int,
        dtype: DTypeLike | None = None,
    ) -> Array:
        """Computes a ``(tsz, embed_dim)`` sinusoidal table.

        Column ``d`` uses the divisor ``base ** (2 * (d // 2) / embed_dim)``. The
        sines of the even columns are followed by the cosines of the odd columns.

        Args:
            tsz: Number of positions (rows), starting at position 0.
            embed_dim: Number of columns.
            dtype: Output dtype; float32 when ``None``.

        Returns:
            The sinusoidal table.
        """
        # Always compute in float32: low-precision dtypes (bfloat16 is exact only up to
        # 256, float16 up to 2048) would round large positions and corrupt the table.
        positions_t = jnp.arange(tsz, dtype=jnp.float32)
        dim_d = jnp.arange(embed_dim, dtype=jnp.float32)
        dim_d = self.base ** (2 * (dim_d // 2) / embed_dim)
        embeddings_td = positions_t[:, None] / dim_d[None, :]
        embeddings_td = jnp.concatenate(
            [jnp.sin(embeddings_td[:, 0::2]), jnp.cos(embeddings_td[:, 1::2])],
            axis=-1,
        )
        return embeddings_td if dtype is None else embeddings_td.astype(dtype)
136
+
137
+
138
def get_rotary_embeddings(
    tsz: int,
    embed_dim: int,
    dtype: jnp.dtype,
    offset: int = 0,
    base: int = 10_000,
) -> Array:
    """Builds the cosine/sine tables used by rotary positional embeddings.

    Only the first half of the embedding dimension is rotated. Each table
    therefore covers ``embed_dim // 2`` channels: ``embed_dim // 4`` frequencies,
    each repeated twice.

    Args:
        tsz: Number of positions to generate.
        embed_dim: Full embedding dimension; must be divisible by 4.
        dtype: Dtype of the returned tables.
        offset: Position of the first generated row.
        base: The base for the sinusoidal frequencies.

    Returns:
        An array of shape ``(2, tsz, embed_dim // 2)``. Index 0 holds the cosine
        table and index 1 the sine table.

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by 4.
    """
    assert embed_dim % 4 == 0, f"Embedding dimension must be divisible by 4, got {embed_dim}"
    rot_dim = embed_dim // 2

    # All math happens in float32; only the final result is cast to `dtype`.
    exponents_f = jnp.arange(0, rot_dim, 2, dtype=jnp.float32) / rot_dim
    inv_freq_f = 1.0 / (base**exponents_f)
    positions_t = jnp.arange(offset, offset + tsz, dtype=jnp.float32)

    # Outer product of positions and frequencies, duplicated along the channel axis.
    angles_tf = positions_t[:, None] * inv_freq_f[None, :]
    angles_tc = jnp.tile(angles_tf, (1, 2))

    return jnp.stack([jnp.cos(angles_tc), jnp.sin(angles_tc)]).astype(dtype)
154
+
155
+
156
+ def apply_rotary_embeddings(x_tc: Array, embs_2tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
157
+ cos_tc, sin_tc = embs_2tc[0], embs_2tc[1]
158
+ tsz, embed_dim = x_tc.shape
159
+ half_d = embed_dim // 2
160
+ quarter_d = embed_dim // 4
161
+ x_rope_tc, x_pass_tc = x_tc[..., :half_d], x_tc[..., half_d:]
162
+ neg_half_x_tc = jnp.concatenate([-x_rope_tc[..., quarter_d:], x_rope_tc[..., :quarter_d]], axis=-1)
163
+ cos_part_tc = cos_tc[offset : offset + tsz] if times_t is None else cos_tc[times_t]
164
+ sin_part_tc = sin_tc[offset : offset + tsz] if times_t is None else sin_tc[times_t]
165
+ x_rope_tc = x_rope_tc * cos_part_tc + neg_half_x_tc * sin_part_tc
166
+ return jnp.concatenate((x_rope_tc, x_pass_tc), axis=-1)
167
+
168
+
169
def rotary_embeddings(x_tc: Array, offset: int = 0, base: int = 10_000) -> Array:
    """Applies rotary embeddings to ``x_tc`` in a single call.

    The cosine/sine tables are built on every call, so nothing has to be
    pre-initialized, which makes this convenient when running online. Only the
    ``tsz`` rows actually needed (positions ``offset`` to ``offset + tsz - 1``)
    are generated, so the cost does not grow with ``offset``.

    Args:
        x_tc: The input tensor, with shape ``(tsz, embed_dim)``; ``embed_dim``
            must be divisible by 4.
        offset: The position of the first element.
        base: The base for the sinusoidal frequencies.

    Returns:
        The input tensor with rotary embeddings applied, same shape as ``x_tc``.
    """
    tsz, embed_dim = x_tc.shape
    # Generate the table starting directly at `offset` (same values as slicing a table
    # built from position 0), then apply it without any further offset.
    emb_2tc = get_rotary_embeddings(tsz, embed_dim, x_tc.dtype, offset, base)
    return apply_rotary_embeddings(x_tc, emb_2tc, 0)
186
+
187
+
188
class RotaryEmbeddings(eqx.Module):
    """Module wrapper around the rotary embedding functions.

    The cosine/sine tables are rebuilt on every call, sized to cover all the
    positions that call needs.

    Parameters:
        base: The base for the sinusoidal frequencies.
    """

    base: int = eqx.field(static=True)

    def __init__(self, base: int = 10_000) -> None:
        """Stores the frequency base.

        Args:
            base: The base for the sinusoidal frequencies.
        """
        super().__init__()
        self.base = base

    def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
        num_steps, channels = x_tc.shape
        if times_t is None:
            needed = num_steps
        else:
            # NOTE: `.item()` needs a concrete value, so explicit positions cannot be traced under `jax.jit`.
            needed = max(num_steps, int(times_t.max().item()) + 1)
        table_2tc = get_rotary_embeddings(needed + offset, channels, x_tc.dtype, 0, self.base)
        return apply_rotary_embeddings(x_tc, table_2tc, offset, times_t)
212
+
213
+
214
@overload
def get_positional_embeddings(kind: Literal["identity"]) -> IdentityPositionalEmbeddings: ...


@overload
def get_positional_embeddings(
    kind: Literal["learned"],
    *,
    max_tsz: int,
    embed_dim: int,
    learnable: bool | None = None,
    key: PRNGKeyArray,
) -> LearnedPositionalEmbeddings: ...


@overload
def get_positional_embeddings(
    kind: Literal["sinusoidal"],
    *,
    max_tsz: int | None = None,
    embed_dim: int | None = None,
    learnable: bool | None = None,
    base: int = 10_000,
) -> SinusoidalEmbeddings: ...


@overload
def get_positional_embeddings(
    kind: Literal["rotary"],
    *,
    base: int = 10_000,
) -> RotaryEmbeddings: ...


# Catch-all overload for callers holding a non-literal `EmbeddingKind`; must stay last.
@overload
def get_positional_embeddings(
    kind: EmbeddingKind,
    *,
    max_tsz: int | None = None,
    embed_dim: int | None = None,
    learnable: bool | None = None,
    base: int = 10_000,
    key: PRNGKeyArray | None = None,
) -> IdentityPositionalEmbeddings | LearnedPositionalEmbeddings | SinusoidalEmbeddings | RotaryEmbeddings: ...


def get_positional_embeddings(
    kind: EmbeddingKind,
    *,
    max_tsz: int | None = None,
    embed_dim: int | None = None,
    learnable: bool | None = None,
    base: int = 10_000,
    key: PRNGKeyArray | None = None,
) -> eqx.Module:
    """Builds a positional embeddings module of the requested kind.

    Args:
        kind: Which module to build: ``"identity"``, ``"learned"``,
            ``"sinusoidal"`` or ``"rotary"``.
        max_tsz: The maximum sequence length. Required for ``"learned"``, and for
            ``"sinusoidal"`` when it is learnable.
        embed_dim: The embedding dimension. Required for ``"learned"``, and for
            ``"sinusoidal"`` when it is learnable.
        learnable: Whether the embeddings are learnable. When omitted,
            ``"learned"`` defaults to ``True`` and ``"sinusoidal"`` to ``False``.
        base: The frequency base used by the ``"sinusoidal"`` and ``"rotary"`` kinds.
        key: The PRNG key for initializing ``"learned"`` embeddings.

    Returns:
        The positional embeddings module.

    Raises:
        AssertionError: If ``"learned"`` is requested without ``max_tsz``,
            ``embed_dim`` or ``key``.
        ValueError: If an invalid embedding kind is supplied.
    """
    if kind == "identity":
        return IdentityPositionalEmbeddings()

    if kind == "learned":
        assert max_tsz is not None, "Learned embeddings require `max_tsz` to be set"
        assert embed_dim is not None, "Learned embeddings require `embed_dim` to be set"
        assert key is not None, "Learned embeddings require `key` to be set"
        return LearnedPositionalEmbeddings(
            max_tsz=max_tsz,
            embed_dim=embed_dim,
            learnable=learnable if learnable is not None else True,
            key=key,
        )

    if kind == "sinusoidal":
        return SinusoidalEmbeddings(
            max_tsz=max_tsz,
            embed_dim=embed_dim,
            learnable=learnable if learnable is not None else False,
            base=base,
        )

    if kind == "rotary":
        return RotaryEmbeddings(base=base)

    raise ValueError(f"Invalid embedding kind: {kind}")
315
+
316
+
317
def fourier_embeddings(t: Array, dim: int, max_period: int = 10000) -> Array:
    """Embeds (possibly continuous) timesteps with sinusoids of several frequencies.

    The ``dim // 2`` frequencies are ``max_period ** (-i / (dim // 2))`` for
    ``i = 0 .. dim // 2 - 1``. They are geometrically spaced from 1 down towards
    ``1 / max_period``. The cosines of ``t`` times each frequency come first,
    followed by the sines. If ``dim`` is odd, a trailing column of zeros pads the
    output.

    Args:
        t: The timesteps. Any shape is accepted, including a scalar; a new
            trailing axis of size ``dim`` is appended.
        dim: The embedding dimension.
        max_period: Sets the lowest frequency (about ``1 / max_period``).

    Returns:
        A float32 array of shape ``t.shape + (dim,)``.
    """
    half = dim // 2
    idxs = jnp.arange(start=0, stop=half, dtype=jnp.float32)
    freqs = jnp.exp(-math.log(max_period) * idxs / half)
    # `[..., None]` (rather than `[:, None]`) lets `t` have any number of leading axes.
    args = jnp.asarray(t)[..., None].astype(jnp.float32) * freqs
    embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
    # Adds an additional column of zeros so the last axis has exactly `dim` entries.
    if dim % 2:
        embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[..., :1])], axis=-1)
    return embedding
327
+
328
+
329
class FourierEmbeddings(eqx.Module):
    """Defines a module for applying Fourier embeddings to timesteps.

    This module differs from the other positional embedding modules because it
    expects a continuous time input, rather than a discrete time input.

    Parameters:
        dim: The number of embedding dimensions. ``dim // 2`` frequencies are used,
            so a higher value gives more, and more finely spaced, frequencies
            between 1 and roughly ``1 / max_period``; the highest frequency is
            always 1.
        max_period: The maximum period for the embeddings. This should roughly
            be in line with the maximum number of timesteps; the default value
            of 10,000 is commonly used in NLP applications.
    """

    # Static, like the integer fields of the other modules in this file, so these stay
    # Python ints instead of becoming traced pytree leaves (e.g. under `jax.jit`), where
    # `jnp.arange` could not use them as a length.
    dim: int = eqx.field(static=True)
    max_period: int = eqx.field(static=True)

    def __init__(self, dim: int, max_period: int = 10000) -> None:
        super().__init__()

        self.dim = dim
        self.max_period = max_period

    def __call__(self, t: Array) -> Array:
        """Embeds the timesteps ``t`` with ``fourier_embeddings``.

        Args:
            t: The timesteps to embed.

        Returns:
            The embeddings, with a trailing axis of size ``dim``.
        """
        return fourier_embeddings(t, self.dim, self.max_period)
xax/nn/functions.py CHANGED
@@ -5,6 +5,7 @@ import random
5
5
  from dataclasses import is_dataclass
6
6
  from typing import Any, Callable, Iterable, Mapping, ParamSpec, Sequence, TypeVar
7
7
 
8
+ import jax.numpy as jnp
8
9
  import numpy as np
9
10
  from jaxtyping import Array
10
11
 
@@ -14,26 +15,29 @@ T = TypeVar("T")
14
15
  P = ParamSpec("P")
15
16
 
16
17
 
17
def recursive_apply(item: Any, func: Callable[[Array], Array], include_numpy: bool = False) -> Any:  # noqa: ANN401
    """Applies a function recursively to tensors in an item.

    Containers are rebuilt while traversing them:

    - mappings become ``dict``;
    - other sequences become ``list``;
    - dataclass instances are reconstructed from their ``__init__`` fields.

    Strings, bytes and numbers are returned unchanged, as is anything not
    recognized.

    Args:
        item: The item to apply the function to
        func: The function to apply (for the tensor)
        include_numpy: If set, NumPy arrays are converted to JAX arrays and passed
            to ``func``; otherwise they are returned unchanged

    Returns:
        The same item, with the function applied
    """
    # `bytes`/`bytearray` are Sequences of ints; treat them as leaves so they are not
    # exploded into lists of integers.
    if isinstance(item, (str, bytes, bytearray, int, float)):
        return item
    if include_numpy and isinstance(item, np.ndarray):
        return func(jnp.array(item))
    if isinstance(item, Array):
        return func(item)
    # `is_dataclass` is also true for dataclass *classes*; only rebuild instances.
    if is_dataclass(item) and not isinstance(item, type):
        from dataclasses import fields  # Local import: the module header only imports `is_dataclass`.

        # Rebuild from the fields accepted by `__init__`: reading `__dict__` fails for
        # slotted dataclasses and passes `init=False` fields the constructor rejects.
        return item.__class__(
            **{f.name: recursive_apply(getattr(item, f.name), func, include_numpy) for f in fields(item) if f.init}
        )
    if isinstance(item, Mapping):
        return {k: recursive_apply(v, func, include_numpy) for k, v in item.items()}
    if isinstance(item, Sequence):
        return [recursive_apply(i, func, include_numpy) for i in item]
    return item
38
42
 
39
43
 
xax/requirements-dev.txt CHANGED
@@ -1,7 +1,15 @@
1
1
  # requirements-dev.txt
2
2
 
3
+ # Linting.
3
4
  black
4
5
  darglint
5
6
  mypy
6
- pytest
7
7
  ruff
8
+
9
+ # Tests.
10
+ pytest
11
+
12
+ # Types.
13
+ types-pillow
14
+ types-psutil
15
+ types-requests
xax/requirements.txt CHANGED
@@ -1,18 +1,25 @@
1
1
  # requirements.txt
2
2
 
3
- dpshdl
4
- equinox
5
- gitpython
3
+ # Core ML/JAX dependencies
6
4
  jax
7
5
  jaxtyping
8
- omegaconf
6
+ equinox
9
7
  optax
8
+ dpshdl
9
+ chex
10
+ importlib-resources
11
+
12
+ # Data processing and serialization
13
+ cloudpickle
10
14
  pillow
11
- psutil
12
- requests
15
+
16
+ # Configuration and project management
17
+ omegaconf
18
+ gitpython
19
+
20
+ # Monitoring and logging
13
21
  tensorboard
22
+ psutil
14
23
 
15
- # Types.
16
- types-pillow
17
- types-psutil
18
- types-requests
24
+ # Networking
25
+ requests
xax/task/base.py CHANGED
@@ -15,6 +15,7 @@ from pathlib import Path
15
15
  from types import TracebackType
16
16
  from typing import Generic, Self, TypeVar, cast
17
17
 
18
+ import jax
18
19
  from omegaconf import Container, DictConfig, OmegaConf
19
20
 
20
21
  from xax.core.state import State
@@ -23,6 +24,7 @@ from xax.utils.text import camelcase_to_snakecase
23
24
  logger = logging.getLogger(__name__)
24
25
 
25
26
 
27
+ @jax.tree_util.register_dataclass
26
28
  @dataclass
27
29
  class BaseConfig:
28
30
  pass
@@ -79,12 +81,6 @@ class BaseTask(Generic[Config]):
79
81
  def on_training_end(self, state: State) -> State:
80
82
  return state
81
83
 
82
- def on_before_save_checkpoint(self, ckpt_path: Path) -> None:
83
- pass
84
-
85
- def on_after_save_checkpoint(self, ckpt_path: Path) -> None:
86
- pass
87
-
88
84
  @functools.cached_property
89
85
  def task_class_name(self) -> str:
90
86
  return self.__class__.__name__