xax 0.0.1-py3-none-any.whl → 0.0.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +256 -1
- xax/core/conf.py +193 -0
- xax/core/state.py +81 -0
- xax/nn/__init__.py +0 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +77 -0
- xax/nn/parallel.py +211 -0
- xax/requirements-dev.txt +15 -0
- xax/requirements.txt +23 -0
- xax/task/__init__.py +0 -0
- xax/task/base.py +207 -0
- xax/task/launchers/__init__.py +0 -0
- xax/task/launchers/base.py +28 -0
- xax/task/launchers/cli.py +42 -0
- xax/task/launchers/single_process.py +30 -0
- xax/task/launchers/staged.py +29 -0
- xax/task/logger.py +783 -0
- xax/task/loggers/__init__.py +0 -0
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/json.py +121 -0
- xax/task/loggers/state.py +45 -0
- xax/task/loggers/stdout.py +170 -0
- xax/task/loggers/tensorboard.py +223 -0
- xax/task/mixins/__init__.py +12 -0
- xax/task/mixins/artifacts.py +114 -0
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +251 -0
- xax/task/mixins/data_loader.py +149 -0
- xax/task/mixins/gpu_stats.py +257 -0
- xax/task/mixins/logger.py +66 -0
- xax/task/mixins/process.py +51 -0
- xax/task/mixins/runnable.py +63 -0
- xax/task/mixins/step_wrapper.py +63 -0
- xax/task/mixins/train.py +541 -0
- xax/task/script.py +53 -0
- xax/task/task.py +65 -0
- xax/utils/__init__.py +0 -0
- xax/utils/data/__init__.py +0 -0
- xax/utils/data/collate.py +206 -0
- xax/utils/experiments.py +802 -0
- xax/utils/jax.py +14 -0
- xax/utils/logging.py +223 -0
- xax/utils/numpy.py +47 -0
- xax/utils/tensorboard.py +258 -0
- xax/utils/text.py +350 -0
- xax-0.0.5.dist-info/METADATA +40 -0
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.5.dist-info/top_level.txt +1 -0
- examples/mnist.py +0 -148
- xax-0.0.1.dist-info/METADATA +0 -21
- xax-0.0.1.dist-info/RECORD +0 -9
- xax-0.0.1.dist-info/top_level.txt +0 -2
- {examples → xax/core}/__init__.py +0 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/nn/embeddings.py
ADDED
@@ -0,0 +1,355 @@
"""Defines embedding layers."""

import math
from typing import Literal, cast, get_args, overload

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
from jaxtyping import Array, DTypeLike, PRNGKeyArray

EmbeddingKind = Literal["identity", "learned", "sinusoidal", "rotary"]


def cast_embedding_kind(k: str) -> EmbeddingKind:
    args = get_args(EmbeddingKind)
    assert k in args, f"Invalid embedding kind: '{k}'; valid options are {args}"
    return cast(EmbeddingKind, k)


class IdentityPositionalEmbeddings(eqx.Module):
    def __call__(self, x: Array, offset: int = 0, times_t: Array | None = None) -> Array:
        return x


class LearnedPositionalEmbeddings(eqx.Module):
    """Defines a learned embeddings module.

    Parameters:
        max_tsz: The maximum sequence length.
        embed_dim: The embedding dimension.
        learnable: Whether the embeddings are learnable.
    """

    max_tsz: int = eqx.field(static=True)
    embed_dim: int = eqx.field(static=True)
    learnable: bool = eqx.field(static=True)
    embeddings_tc: Array

    def __init__(
        self,
        max_tsz: int,
        embed_dim: int,
        learnable: bool = True,
        *,
        key: PRNGKeyArray,
    ) -> None:
        super().__init__()

        self.max_tsz = max_tsz
        self.embed_dim = embed_dim
        self.learnable = learnable

        self.embeddings_tc = jrandom.normal(key, (max_tsz, embed_dim))

    def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
        if times_t is None:
            emb_tc = self.embeddings_tc[offset : offset + x_tc.shape[-2]]
        else:
            emb_tc = self.embeddings_tc[times_t]
        if not self.learnable:
            emb_tc = jax.lax.stop_gradient(emb_tc)
        return emb_tc


class SinusoidalEmbeddings(eqx.Module):
    """Defines a sinusoidal embeddings module.

    Parameters:
        embed_dim: The embedding dimension.
        max_tsz: The maximum sequence length.
        learnable: Whether the embeddings are learnable.
        base: The base for the sinusoidal embeddings.
    """

    base: int = eqx.field(static=True)
    max_tsz: int | None = eqx.field(static=True)
    embed_dim: int | None = eqx.field(static=True)
    embeddings_tc: Array | None

    def __init__(
        self,
        embed_dim: int | None = None,
        max_tsz: int | None = None,
        learnable: bool = True,
        base: int = 10_000,
    ) -> None:
        super().__init__()

        self.max_tsz = max_tsz
        self.embed_dim = embed_dim
        self.base = base

        self.embeddings_tc: Array | None = None
        if learnable:
            assert max_tsz is not None, "Learnable parameters require `max_tsz` to be set"
            assert embed_dim is not None, "Learnable parameters require `embed_dim` to be set"
            self.embeddings_tc = self.get_embeddings(max_tsz, embed_dim)

    def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
        tsz, dims = x_tc.shape

        # If the embeddings are learnable, use the stored parameter;
        # otherwise, compute the embeddings on the fly.
        if self.embeddings_tc is None:
            if times_t is None:
                embeddings_tc = self.get_embeddings(offset + tsz, dims, x_tc.dtype)
            else:
                embeddings_tc = self.get_embeddings(times_t.max().item() + 1, dims, x_tc.dtype)
        else:
            embeddings_tc = self.embeddings_tc

        # Get only the embeddings for the specified time steps.
        if times_t is None:
            embeddings_tc = embeddings_tc[offset : offset + tsz]
        else:
            embeddings_tc = embeddings_tc[times_t]

        return x_tc + embeddings_tc

    def get_embeddings(
        self,
        tsz: int,
        embed_dim: int,
        dtype: DTypeLike | None = None,
    ) -> Array:
        positions_t = jnp.arange(tsz, dtype=dtype)
        dim_d = jnp.arange(embed_dim, dtype=dtype)
        dim_d = self.base ** (2 * (dim_d // 2) / embed_dim)
        embeddings_td = positions_t[:, None] / dim_d[None, :]
        embeddings_td = jnp.concatenate(
            [jnp.sin(embeddings_td[:, 0::2]), jnp.cos(embeddings_td[:, 1::2])],
            axis=-1,
        )
        return embeddings_td.astype(dtype)
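
# `SinusoidalEmbeddings.get_embeddings` follows the standard sinusoidal scheme:
# position t and channel d produce t / base ** (2 * (d // 2) / embed_dim), with
# sines taken over the even columns and cosines over the odd columns. Unlike
# the interleaved layout in the original Transformer paper, the two halves are
# concatenated along the channel axis.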


def get_rotary_embeddings(
    tsz: int,
    embed_dim: int,
    dtype: jnp.dtype,
    offset: int = 0,
    base: int = 10_000,
) -> Array:
    assert embed_dim % 4 == 0, f"Embedding dimension must be divisible by 4, got {embed_dim}"
    half_d = embed_dim // 2
    theta = 1.0 / (base ** (jnp.arange(0, half_d, 2, dtype=jnp.float32) / half_d))
    seq_idx = jnp.arange(offset, tsz + offset, dtype=jnp.float32)
    idx_theta_tc = jnp.einsum("t,c->tc", seq_idx, theta)
    idx_theta2_tc = jnp.concatenate([idx_theta_tc, idx_theta_tc], axis=1)
    cos_tc, sin_tc = jnp.cos(idx_theta2_tc), jnp.sin(idx_theta2_tc)
    emb_2tc = jnp.stack((cos_tc, sin_tc), axis=0)
    return emb_2tc.astype(dtype)


def apply_rotary_embeddings(x_tc: Array, embs_2tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
    cos_tc, sin_tc = embs_2tc[0], embs_2tc[1]
    tsz, embed_dim = x_tc.shape
    half_d = embed_dim // 2
    quarter_d = embed_dim // 4
    x_rope_tc, x_pass_tc = x_tc[..., :half_d], x_tc[..., half_d:]
    neg_half_x_tc = jnp.concatenate([-x_rope_tc[..., quarter_d:], x_rope_tc[..., :quarter_d]], axis=-1)
    cos_part_tc = cos_tc[offset : offset + tsz] if times_t is None else cos_tc[times_t]
    sin_part_tc = sin_tc[offset : offset + tsz] if times_t is None else sin_tc[times_t]
    x_rope_tc = x_rope_tc * cos_part_tc + neg_half_x_tc * sin_part_tc
    return jnp.concatenate((x_rope_tc, x_pass_tc), axis=-1)
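
# The rotate-half construction above is a per-pair 2D rotation: within the RoPE
# half, channel j is paired with channel j + quarter_d, and for the angle
# a = t * theta_j the pair maps to
#     (x_j * cos(a) - x_{j + quarter_d} * sin(a),
#      x_{j + quarter_d} * cos(a) + x_j * sin(a)),
# which is why `neg_half_x_tc` negates and swaps the two quarters before being
# multiplied by the sine term. The second half of the channels (`x_pass_tc`)
# passes through unchanged.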


def rotary_embeddings(x_tc: Array, offset: int = 0, base: int = 10_000) -> Array:
    """Defines a single function for applying rotary embeddings.

    This is slower than using the module, but it doesn't require
    pre-initializing the embeddings, so it can be used when running online.

    Args:
        x_tc: The input tensor, with shape ``(tsz, embed_dim)``.
        offset: The offset for the first element.
        base: The base for the sinusoidal embeddings.

    Returns:
        The input tensor with rotary embeddings applied.
    """
    (tsz, embed_dim), dtype = x_tc.shape, x_tc.dtype
    emb_2tc = get_rotary_embeddings(tsz + offset, embed_dim, dtype, 0, base)
    return apply_rotary_embeddings(x_tc, emb_2tc, offset)


class RotaryEmbeddings(eqx.Module):
    """Defines a rotary embeddings module.

    Parameters:
        base: The base for the sinusoidal embeddings.
    """

    base: int = eqx.field(static=True)

    def __init__(self, base: int = 10_000) -> None:
        """Initializes a rotary embeddings module.

        Args:
            base: The base for the sinusoidal embeddings.
        """
        super().__init__()

        self.base = base

    def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
        tsz, embed_dim = x_tc.shape
        max_tsz = max(tsz, 0 if times_t is None else int(times_t.max().item()) + 1) + offset
        emb_2tc = get_rotary_embeddings(max_tsz, embed_dim, x_tc.dtype, 0, self.base)
        return apply_rotary_embeddings(x_tc, emb_2tc, offset, times_t)


@overload
def get_positional_embeddings(kind: Literal["identity"]) -> IdentityPositionalEmbeddings: ...


@overload
def get_positional_embeddings(
    kind: Literal["learned"],
    *,
    max_tsz: int,
    embed_dim: int,
    learnable: bool | None = None,
    key: PRNGKeyArray,
) -> LearnedPositionalEmbeddings: ...


@overload
def get_positional_embeddings(
    kind: Literal["sinusoidal"],
    *,
    max_tsz: int | None = None,
    embed_dim: int | None = None,
    learnable: bool | None = None,
    base: int = 10_000,
) -> SinusoidalEmbeddings: ...


@overload
def get_positional_embeddings(
    kind: Literal["rotary"],
    *,
    base: int = 10_000,
) -> RotaryEmbeddings: ...


@overload
def get_positional_embeddings(
    kind: EmbeddingKind,
    *,
    max_tsz: int | None = None,
    embed_dim: int | None = None,
    learnable: bool | None = None,
    base: int = 10_000,
    key: PRNGKeyArray | None = None,
) -> IdentityPositionalEmbeddings | LearnedPositionalEmbeddings | SinusoidalEmbeddings | RotaryEmbeddings: ...


def get_positional_embeddings(
    kind: EmbeddingKind,
    *,
    max_tsz: int | None = None,
    embed_dim: int | None = None,
    learnable: bool | None = None,
    base: int = 10_000,
    key: PRNGKeyArray | None = None,
) -> eqx.Module:
    """Defines the common module for adding positional embeddings.

    Args:
        kind: The type of embedding to use.
        max_tsz: The maximum sequence length.
        embed_dim: The embedding dimension.
        learnable: Whether the embeddings are learnable; if not provided,
            uses sensible defaults.
        base: The base for the sinusoidal embeddings.
        key: The PRNG key for initializing learnable embeddings.

    Returns:
        The positional embeddings module.

    Raises:
        ValueError: If an invalid embedding kind is supplied.
    """
    match kind:
        case "identity":
            return IdentityPositionalEmbeddings()

        case "learned":
            assert max_tsz is not None, "Learned embeddings require `max_tsz` to be set"
            assert embed_dim is not None, "Learned embeddings require `embed_dim` to be set"
            assert key is not None, "Learned embeddings require `key` to be set"

            return LearnedPositionalEmbeddings(
                max_tsz=max_tsz,
                embed_dim=embed_dim,
                learnable=True if learnable is None else learnable,
                key=key,
            )

        case "sinusoidal":
            return SinusoidalEmbeddings(
                max_tsz=max_tsz,
                embed_dim=embed_dim,
                learnable=False if learnable is None else learnable,
                base=base,
            )

        case "rotary":
            return RotaryEmbeddings(base=base)

        case _:
            raise ValueError(f"Invalid embedding kind: {kind}")


def fourier_embeddings(t: Array, dim: int, max_period: int = 10000) -> Array:
    half = dim // 2
    idxs = jnp.arange(start=0, stop=half, dtype=jnp.float32)
    freqs = jnp.exp(-math.log(max_period) * idxs / half)
    args = t[:, None].astype(jnp.float32) * freqs[None]
    embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
    # Adds an additional column of zeros to match the expected dimension.
    if dim % 2:
        embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1)
    return embedding


class FourierEmbeddings(eqx.Module):
    """Defines a module for applying Fourier embeddings to timesteps.

    This module differs from the other positional embedding modules because it
    expects a continuous time input, rather than a discrete time input.

    Parameters:
        dim: The number of embedding dimensions. This value is used to determine
            how many different frequencies to use, and a higher value means
            higher frequencies.
        max_period: The maximum period for the embeddings. This should roughly
            be in line with the maximum number of timesteps; the default value
            of 10,000 is commonly used in NLP applications, and is derived from
            operating on sequence lengths of 100 to 1000 tokens.
    """

    dim: int
    max_period: int

    def __init__(self, dim: int, max_period: int = 10000) -> None:
        super().__init__()

        self.dim = dim
        self.max_period = max_period

    def __call__(self, t: Array) -> Array:
        return fourier_embeddings(t, self.dim, self.max_period)
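
A minimal usage sketch for the factory above (this is an annotation, not part of the package: the import path follows the file's location in the wheel, and the shapes, key, and values are illustrative):

import jax.random as jrandom

from xax.nn.embeddings import get_positional_embeddings

key = jrandom.PRNGKey(0)
x_tc = jrandom.normal(key, (16, 32))  # (tsz, embed_dim); rotary requires embed_dim % 4 == 0

rotary = get_positional_embeddings("rotary")
y_tc = rotary(x_tc)  # same shape, with the first half of the channels rotated

learned = get_positional_embeddings("learned", max_tsz=128, embed_dim=32, key=key)
emb_tc = learned(x_tc)  # returns the (16, 32) slice of the learned table, not x + emb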
xax/nn/functions.py
ADDED
@@ -0,0 +1,77 @@
# mypy: disable-error-code="override"
"""Defines helper JAX functions."""

import random
from dataclasses import is_dataclass
from typing import Any, Callable, Iterable, Mapping, ParamSpec, Sequence, TypeVar

import jax.numpy as jnp
import numpy as np
from jaxtyping import Array

from xax.core.conf import load_user_config

T = TypeVar("T")
P = ParamSpec("P")


def recursive_apply(item: Any, func: Callable[[Array], Array], include_numpy: bool = False) -> Any:  # noqa: ANN401
    """Applies a function recursively to tensors in an item.

    Args:
        item: The item to apply the function to
        func: The function to apply (to each tensor)
        include_numpy: If set, also convert and apply to numpy arrays

    Returns:
        The same item, with the function applied
    """
    if isinstance(item, (str, int, float)):
        return item
    if include_numpy and isinstance(item, np.ndarray):
        return func(jnp.array(item))
    if isinstance(item, Array):
        return func(item)
    if is_dataclass(item):
        return item.__class__(**{k: recursive_apply(v, func, include_numpy) for k, v in item.__dict__.items()})
    if isinstance(item, Mapping):
        return {k: recursive_apply(v, func, include_numpy) for k, v in item.items()}
    if isinstance(item, Sequence):
        return [recursive_apply(i, func, include_numpy) for i in item]
    return item


def recursive_chunk(item: Any, num_chunks: int, dim: int = 0) -> Iterable[Any]:  # noqa: ANN401
    """Recursively chunks tensors into ``num_chunks`` pieces.

    Args:
        item: The item to recursively chunk
        num_chunks: The number of splits to make
        dim: The split dimension

    Yields:
        ``num_chunks`` chunks of items
    """
    if isinstance(item, (str, int, float)):
        yield from (item for _ in range(num_chunks))
    elif isinstance(item, np.ndarray):
        yield from np.array_split(item, num_chunks, axis=dim)
    elif is_dataclass(item):
        yield from (
            item.__class__(**{k: i for k, i in zip(item.__dict__, ii)})
            for ii in zip(*(recursive_chunk(v, num_chunks, dim) for v in item.__dict__.values()))
        )
    elif isinstance(item, Mapping):
        yield from (dict(zip(item, ii)) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item.values())))
    elif isinstance(item, Sequence):
        yield from (list(ii) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item)))
    else:
        yield from (item for _ in range(num_chunks))


def set_random_seed(seed: int | None = None, offset: int = 0) -> None:
    if seed is None:
        seed = load_user_config().experiment.default_random_seed
    seed += offset
    random.seed(seed)
    np.random.seed(seed)
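
A short sketch of how these helpers compose (an annotation, not part of the package; the container contents are illustrative):

import jax.numpy as jnp
import numpy as np

from xax.nn.functions import recursive_apply, recursive_chunk

batch = {"x": np.zeros((8, 3)), "label": "cat"}

# Convert every numpy array in the container to a float16 JAX array;
# non-array leaves such as the label pass through untouched.
half = recursive_apply(batch, lambda a: a.astype(jnp.float16), include_numpy=True)

# Split numpy arrays along the leading dimension; strings are repeated
# into each chunk.
first, second = recursive_chunk(batch, 2, dim=0)  # each "x" is (4, 3)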
xax/nn/parallel.py
ADDED
@@ -0,0 +1,211 @@
"""Defines functions for parallelism."""

import os

_RANK: int | None = None
_LOCAL_RANK: int | None = None
_WORLD_SIZE: int | None = None
_LOCAL_WORLD_SIZE: int | None = None
_MASTER_ADDR: str | None = None
_MASTER_PORT: int | None = None
_INIT_METHOD: str | None = None


def set_rank(rank: int) -> None:
    global _RANK

    if rank != _RANK:
        _RANK = rank
        os.environ["RANK"] = str(rank)
    else:
        raise ValueError(f"Rank {rank} is already set")


def get_rank_optional() -> int | None:
    return _RANK


def get_rank() -> int:
    return 0 if _RANK is None else _RANK


def clear_rank() -> None:
    global _RANK

    _RANK = None
    os.environ.pop("RANK", None)


def set_local_rank(rank: int) -> None:
    global _LOCAL_RANK

    if rank != _LOCAL_RANK:
        _LOCAL_RANK = rank
        os.environ["LOCAL_RANK"] = str(rank)
    else:
        raise ValueError(f"Local rank {rank} is already set")


def get_local_rank_optional() -> int | None:
    return _LOCAL_RANK


def get_local_rank() -> int:
    return 0 if _LOCAL_RANK is None else _LOCAL_RANK


def clear_local_rank() -> None:
    global _LOCAL_RANK

    _LOCAL_RANK = None
    os.environ.pop("LOCAL_RANK", None)


def set_world_size(world_size: int) -> None:
    global _WORLD_SIZE

    if world_size != _WORLD_SIZE:
        _WORLD_SIZE = world_size
        os.environ["WORLD_SIZE"] = str(world_size)
    else:
        raise ValueError(f"World size {world_size} is already set")


def get_world_size_optional() -> int | None:
    return _WORLD_SIZE


def get_world_size() -> int:
    return 1 if _WORLD_SIZE is None else _WORLD_SIZE


def clear_world_size() -> None:
    global _WORLD_SIZE

    _WORLD_SIZE = None
    os.environ.pop("WORLD_SIZE", None)


def set_local_world_size(local_world_size: int) -> None:
    global _LOCAL_WORLD_SIZE

    if local_world_size != _LOCAL_WORLD_SIZE:
        _LOCAL_WORLD_SIZE = local_world_size
        os.environ["LOCAL_WORLD_SIZE"] = str(local_world_size)
    else:
        raise ValueError(f"Local world size {local_world_size} is already set")


def get_local_world_size_optional() -> int | None:
    return _LOCAL_WORLD_SIZE


def get_local_world_size() -> int:
    return 1 if _LOCAL_WORLD_SIZE is None else _LOCAL_WORLD_SIZE


def clear_local_world_size() -> None:
    global _LOCAL_WORLD_SIZE

    _LOCAL_WORLD_SIZE = None
    os.environ.pop("LOCAL_WORLD_SIZE", None)


def set_master_addr(master_addr: str) -> None:
    global _MASTER_ADDR

    if master_addr != _MASTER_ADDR:
        os.environ["MASTER_ADDR"] = _MASTER_ADDR = master_addr
    else:
        raise ValueError(f"Master address {master_addr} is already set")


def get_master_addr() -> str:
    assert _MASTER_ADDR is not None, "Master address is not yet set"
    return _MASTER_ADDR


def clear_master_addr() -> None:
    global _MASTER_ADDR

    _MASTER_ADDR = None
    os.environ.pop("MASTER_ADDR", None)


def set_master_port(port: int) -> None:
    global _MASTER_PORT

    if port != _MASTER_PORT:
        _MASTER_PORT = port
        os.environ["MASTER_PORT"] = str(port)
    else:
        raise ValueError(f"Master port {port} is already set")


def get_master_port() -> int:
    assert _MASTER_PORT is not None, "Master port is not yet set"
    return _MASTER_PORT


def clear_master_port() -> None:
    global _MASTER_PORT

    _MASTER_PORT = None
    os.environ.pop("MASTER_PORT", None)


def is_master() -> bool:
    return get_rank() == 0


def is_distributed() -> bool:
    return _INIT_METHOD is not None


def set_init_method(init_method: str) -> None:
    global _INIT_METHOD

    if init_method != _INIT_METHOD:
        os.environ["INIT_METHOD"] = _INIT_METHOD = init_method
    else:
        raise ValueError(f"Init method {init_method} is already set")


def get_init_method() -> str:
    assert _INIT_METHOD is not None, "Init method is not yet set"
    return _INIT_METHOD


def clear_init_method() -> None:
    global _INIT_METHOD

    _INIT_METHOD = None
    os.environ.pop("INIT_METHOD", None)


def set_dist(
    rank: int,
    local_rank: int,
    world_size: int,
    local_world_size: int,
    master_addr: str,
    master_port: int,
    init_method: str,
) -> None:
    set_rank(rank)
    set_local_rank(local_rank)
    set_world_size(world_size)
    set_local_world_size(local_world_size)
    set_master_addr(master_addr)
    set_master_port(master_port)
    set_init_method(init_method)


def clear_dist() -> None:
    clear_rank()
    clear_local_rank()
    clear_world_size()
    clear_local_world_size()
    clear_master_addr()
    clear_master_port()
    clear_init_method()
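
A sketch of the intended call pattern (an annotation, not part of the package; the topology values are illustrative). Nothing here launches processes: the setters only record the values in module globals and mirror them into environment variables such as RANK and MASTER_PORT.

from xax.nn.parallel import clear_dist, get_world_size, is_master, set_dist

set_dist(
    rank=0,
    local_rank=0,
    world_size=2,
    local_world_size=2,
    master_addr="127.0.0.1",
    master_port=29500,
    init_method="tcp://127.0.0.1:29500",
)

assert is_master() and get_world_size() == 2  # rank 0 acts as the master

clear_dist()  # resets the globals and removes the environment variables

Note that, as written, each setter raises ValueError if called again with the value it already holds, so clear_dist must run before the topology can be re-registered.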
xax/requirements-dev.txt
ADDED
xax/requirements.txt
ADDED
@@ -0,0 +1,23 @@
# requirements.txt

# Core ML/JAX dependencies
jax
jaxtyping
equinox
optax
dpshdl

# Data processing and serialization
cloudpickle
pillow

# Configuration and project management
omegaconf
gitpython

# Monitoring and logging
tensorboard
psutil

# Networking
requests
xax/task/__init__.py
ADDED
File without changes