xax 0.0.1-py3-none-any.whl → 0.0.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. xax/__init__.py +256 -1
  2. xax/core/conf.py +193 -0
  3. xax/core/state.py +81 -0
  4. xax/nn/__init__.py +0 -0
  5. xax/nn/embeddings.py +355 -0
  6. xax/nn/functions.py +77 -0
  7. xax/nn/parallel.py +211 -0
  8. xax/requirements-dev.txt +15 -0
  9. xax/requirements.txt +23 -0
  10. xax/task/__init__.py +0 -0
  11. xax/task/base.py +207 -0
  12. xax/task/launchers/__init__.py +0 -0
  13. xax/task/launchers/base.py +28 -0
  14. xax/task/launchers/cli.py +42 -0
  15. xax/task/launchers/single_process.py +30 -0
  16. xax/task/launchers/staged.py +29 -0
  17. xax/task/logger.py +783 -0
  18. xax/task/loggers/__init__.py +0 -0
  19. xax/task/loggers/callback.py +56 -0
  20. xax/task/loggers/json.py +121 -0
  21. xax/task/loggers/state.py +45 -0
  22. xax/task/loggers/stdout.py +170 -0
  23. xax/task/loggers/tensorboard.py +223 -0
  24. xax/task/mixins/__init__.py +12 -0
  25. xax/task/mixins/artifacts.py +114 -0
  26. xax/task/mixins/checkpointing.py +209 -0
  27. xax/task/mixins/cpu_stats.py +251 -0
  28. xax/task/mixins/data_loader.py +149 -0
  29. xax/task/mixins/gpu_stats.py +257 -0
  30. xax/task/mixins/logger.py +66 -0
  31. xax/task/mixins/process.py +51 -0
  32. xax/task/mixins/runnable.py +63 -0
  33. xax/task/mixins/step_wrapper.py +63 -0
  34. xax/task/mixins/train.py +541 -0
  35. xax/task/script.py +53 -0
  36. xax/task/task.py +65 -0
  37. xax/utils/__init__.py +0 -0
  38. xax/utils/data/__init__.py +0 -0
  39. xax/utils/data/collate.py +206 -0
  40. xax/utils/experiments.py +802 -0
  41. xax/utils/jax.py +14 -0
  42. xax/utils/logging.py +223 -0
  43. xax/utils/numpy.py +47 -0
  44. xax/utils/tensorboard.py +258 -0
  45. xax/utils/text.py +350 -0
  46. xax-0.0.5.dist-info/METADATA +40 -0
  47. xax-0.0.5.dist-info/RECORD +52 -0
  48. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
  49. xax-0.0.5.dist-info/top_level.txt +1 -0
  50. examples/mnist.py +0 -148
  51. xax-0.0.1.dist-info/METADATA +0 -21
  52. xax-0.0.1.dist-info/RECORD +0 -9
  53. xax-0.0.1.dist-info/top_level.txt +0 -2
  54. {examples → xax/core}/__init__.py +0 -0
  55. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/nn/embeddings.py ADDED
@@ -0,0 +1,355 @@
+ """Defines embedding layers."""
+
+ import math
+ from typing import Literal, cast, get_args, overload
+
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ import jax.random as jrandom
+ from jaxtyping import Array, DTypeLike, PRNGKeyArray
+
+ EmbeddingKind = Literal["identity", "learned", "sinusoidal", "rotary"]
+
+
+ def cast_embedding_kind(k: str) -> EmbeddingKind:
+     args = get_args(EmbeddingKind)
+     assert k in args, f"Invalid embedding kind: '{k}'. Valid options are {args}"
+     return cast(EmbeddingKind, k)
+
+
+ class IdentityPositionalEmbeddings(eqx.Module):
+     def __call__(self, x: Array, offset: int = 0, times_t: Array | None = None) -> Array:
+         return x
+
+
+ class LearnedPositionalEmbeddings(eqx.Module):
+     """Defines a learned embeddings module.
+
+     Parameters:
+         max_tsz: The maximum sequence length.
+         embed_dim: The embedding dimension.
+         learnable: Whether the embeddings are learnable.
+     """
+
+     max_tsz: int = eqx.field(static=True)
+     embed_dim: int = eqx.field(static=True)
+     learnable: bool = eqx.field(static=True)
+     embeddings_tc: Array
+
+     def __init__(
+         self,
+         max_tsz: int,
+         embed_dim: int,
+         learnable: bool = True,
+         *,
+         key: PRNGKeyArray,
+     ) -> None:
+         super().__init__()
+
+         self.max_tsz = max_tsz
+         self.embed_dim = embed_dim
+         self.learnable = learnable
+
+         self.embeddings_tc = jrandom.normal(key, (max_tsz, embed_dim))
+
+     def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
+         if times_t is None:
+             emb_tc = self.embeddings_tc[offset : offset + x_tc.shape[-2]]
+         else:
+             emb_tc = self.embeddings_tc[times_t]
+         if not self.learnable:
+             emb_tc = jax.lax.stop_gradient(emb_tc)
+         return emb_tc
+
+
+ class SinusoidalEmbeddings(eqx.Module):
+     """Defines a sinusoidal embeddings module.
+
+     Parameters:
+         embed_dim: The embedding dimension.
+         max_tsz: The maximum sequence length.
+         learnable: Whether the embeddings are learnable.
+         base: The base for the sinusoidal embeddings.
+     """
+
+     base: int = eqx.field(static=True)
+     max_tsz: int | None = eqx.field(static=True)
+     embed_dim: int | None = eqx.field(static=True)
+     embeddings_tc: Array | None
+
+     def __init__(
+         self,
+         embed_dim: int | None = None,
+         max_tsz: int | None = None,
+         learnable: bool = True,
+         base: int = 10_000,
+     ) -> None:
+         super().__init__()
+
+         self.max_tsz = max_tsz
+         self.embed_dim = embed_dim
+         self.base = base
+
+         self.embeddings_tc: Array | None = None
+         if learnable:
+             assert max_tsz is not None, "Learnable parameters require `max_tsz` to be set"
+             assert embed_dim is not None, "Learnable parameters require `embed_dim` to be set"
+             self.embeddings_tc = self.get_embeddings(max_tsz, embed_dim)
+
+     def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
+         tsz, dims = x_tc.shape
+
+         # If the embeddings were pre-computed (the learnable case), use the
+         # stored parameter; otherwise, compute them on the fly.
+         if self.embeddings_tc is None:
+             if times_t is None:
+                 embeddings_tc = self.get_embeddings(offset + tsz, dims, x_tc.dtype)
+             else:
+                 embeddings_tc = self.get_embeddings(times_t.max().item() + 1, dims, x_tc.dtype)
+         else:
+             embeddings_tc = self.embeddings_tc
+
+         # Get only the embeddings for the specified time steps.
+         if times_t is None:
+             embeddings_tc = embeddings_tc[offset : offset + tsz]
+         else:
+             embeddings_tc = embeddings_tc[times_t]
+
+         return x_tc + embeddings_tc
+
+     def get_embeddings(
+         self,
+         tsz: int,
+         embed_dim: int,
+         dtype: DTypeLike | None = None,
+     ) -> Array:
+         positions_t = jnp.arange(tsz, dtype=dtype)
+         dim_d = jnp.arange(embed_dim, dtype=dtype)
+         dim_d = self.base ** (2 * (dim_d // 2) / embed_dim)
+         embeddings_td = positions_t[:, None] / dim_d[None, :]
+         embeddings_td = jnp.concatenate(
+             [jnp.sin(embeddings_td[:, 0::2]), jnp.cos(embeddings_td[:, 1::2])],
+             axis=-1,
+         )
+         return embeddings_td.astype(dtype)
+
+
+ def get_rotary_embeddings(
+     tsz: int,
+     embed_dim: int,
+     dtype: jnp.dtype,
+     offset: int = 0,
+     base: int = 10_000,
+ ) -> Array:
+     assert embed_dim % 4 == 0, f"Embedding dimension must be divisible by 4, got {embed_dim}"
+     half_d = embed_dim // 2
+     theta = 1.0 / (base ** (jnp.arange(0, half_d, 2, dtype=jnp.float32) / half_d))
+     seq_idx = jnp.arange(offset, tsz + offset, dtype=jnp.float32)
+     idx_theta_tc = jnp.einsum("t,c->tc", seq_idx, theta)
+     idx_theta2_tc = jnp.concatenate([idx_theta_tc, idx_theta_tc], axis=1)
+     cos_tc, sin_tc = jnp.cos(idx_theta2_tc), jnp.sin(idx_theta2_tc)
+     emb_2tc = jnp.stack((cos_tc, sin_tc), axis=0)
+     return emb_2tc.astype(dtype)
+
+
+ def apply_rotary_embeddings(x_tc: Array, embs_2tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
+     cos_tc, sin_tc = embs_2tc[0], embs_2tc[1]
+     tsz, embed_dim = x_tc.shape
+     half_d = embed_dim // 2
+     quarter_d = embed_dim // 4
+     x_rope_tc, x_pass_tc = x_tc[..., :half_d], x_tc[..., half_d:]
+     neg_half_x_tc = jnp.concatenate([-x_rope_tc[..., quarter_d:], x_rope_tc[..., :quarter_d]], axis=-1)
+     cos_part_tc = cos_tc[offset : offset + tsz] if times_t is None else cos_tc[times_t]
+     sin_part_tc = sin_tc[offset : offset + tsz] if times_t is None else sin_tc[times_t]
+     x_rope_tc = x_rope_tc * cos_part_tc + neg_half_x_tc * sin_part_tc
+     return jnp.concatenate((x_rope_tc, x_pass_tc), axis=-1)
+
+
+ def rotary_embeddings(x_tc: Array, offset: int = 0, base: int = 10_000) -> Array:
+     """Defines a single function for applying rotary embeddings.
+
+     This is slower than using the module, but it doesn't require
+     pre-initializing the embeddings, so it can be used when running online.
+
+     Args:
+         x_tc: The input tensor, with shape ``(tsz, embed_dim)``.
+         offset: The offset for the first element.
+         base: The base for the sinusoidal embeddings.
+
+     Returns:
+         The input tensor with rotary embeddings applied.
+     """
+     (tsz, embed_dim), dtype = x_tc.shape, x_tc.dtype
+     emb_2tc = get_rotary_embeddings(tsz + offset, embed_dim, dtype, 0, base)
+     return apply_rotary_embeddings(x_tc, emb_2tc, offset)
+
+
+ class RotaryEmbeddings(eqx.Module):
+     """Defines a rotary embeddings module.
+
+     Parameters:
+         base: The base for the sinusoidal embeddings.
+     """
+
+     base: int = eqx.field(static=True)
+
+     def __init__(self, base: int = 10_000) -> None:
+         """Defines a rotary embeddings module.
+
+         Args:
+             base: The base for the sinusoidal embeddings.
+         """
+         super().__init__()
+
+         self.base = base
+
+     def __call__(self, x_tc: Array, offset: int = 0, times_t: Array | None = None) -> Array:
+         tsz, embed_dim = x_tc.shape
+         max_tsz = max(tsz, 0 if times_t is None else int(times_t.max().item()) + 1) + offset
+         emb_2tc = get_rotary_embeddings(max_tsz, embed_dim, x_tc.dtype, 0, self.base)
+         return apply_rotary_embeddings(x_tc, emb_2tc, offset, times_t)
+
+
+ @overload
+ def get_positional_embeddings(kind: Literal["identity"]) -> IdentityPositionalEmbeddings: ...
+
+
+ @overload
+ def get_positional_embeddings(
+     kind: Literal["learned"],
+     *,
+     max_tsz: int,
+     embed_dim: int,
+     learnable: bool | None = None,
+     key: PRNGKeyArray,
+ ) -> LearnedPositionalEmbeddings: ...
+
+
+ @overload
+ def get_positional_embeddings(
+     kind: Literal["sinusoidal"],
+     *,
+     max_tsz: int | None = None,
+     embed_dim: int | None = None,
+     learnable: bool | None = None,
+     base: int = 10_000,
+ ) -> SinusoidalEmbeddings: ...
+
+
+ @overload
+ def get_positional_embeddings(
+     kind: Literal["rotary"],
+     *,
+     base: int = 10_000,
+ ) -> RotaryEmbeddings: ...
+
+
+ @overload
+ def get_positional_embeddings(
+     kind: EmbeddingKind,
+     *,
+     max_tsz: int | None = None,
+     embed_dim: int | None = None,
+     learnable: bool | None = None,
+     base: int = 10_000,
+     key: PRNGKeyArray | None = None,
+ ) -> IdentityPositionalEmbeddings | LearnedPositionalEmbeddings | SinusoidalEmbeddings | RotaryEmbeddings: ...
+
+
+ def get_positional_embeddings(
+     kind: EmbeddingKind,
+     *,
+     max_tsz: int | None = None,
+     embed_dim: int | None = None,
+     learnable: bool | None = None,
+     base: int = 10_000,
+     key: PRNGKeyArray | None = None,
+ ) -> eqx.Module:
+     """Defines the common module for adding positional embeddings.
+
+     Args:
+         kind: The type of embedding to use.
+         max_tsz: The maximum sequence length.
+         embed_dim: The embedding dimension.
+         learnable: Whether the embeddings are learnable; if not provided,
+             uses sensible defaults.
+         base: The base for the sinusoidal embeddings.
+         key: The PRNG key for initializing learnable embeddings.
+
+     Returns:
+         The positional embeddings module.
+
+     Raises:
+         ValueError: If an invalid embedding kind is supplied.
+     """
+     match kind:
+         case "identity":
+             return IdentityPositionalEmbeddings()
+
+         case "learned":
+             assert max_tsz is not None, "Learned embeddings require `max_tsz` to be set"
+             assert embed_dim is not None, "Learned embeddings require `embed_dim` to be set"
+             assert key is not None, "Learned embeddings require `key` to be set"
+
+             return LearnedPositionalEmbeddings(
+                 max_tsz=max_tsz,
+                 embed_dim=embed_dim,
+                 learnable=True if learnable is None else learnable,
+                 key=key,
+             )
+
+         case "sinusoidal":
+             return SinusoidalEmbeddings(
+                 max_tsz=max_tsz,
+                 embed_dim=embed_dim,
+                 learnable=False if learnable is None else learnable,
+                 base=base,
+             )
+
+         case "rotary":
+             return RotaryEmbeddings(base=base)
+
+         case _:
+             raise ValueError(f"Invalid embedding kind: {kind}")
+
+
+ def fourier_embeddings(t: Array, dim: int, max_period: int = 10000) -> Array:
+     half = dim // 2
+     idxs = jnp.arange(start=0, stop=half, dtype=jnp.float32)
+     freqs = jnp.exp(-math.log(max_period) * idxs / half)
+     args = t[:, None].astype(jnp.float32) * freqs[None]
+     embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
+     # Adds an additional column of zeros to match the expected dimension.
+     if dim % 2:
+         embedding = jnp.concatenate([embedding, jnp.zeros_like(embedding[:, :1])], axis=-1)
+     return embedding
+
+
+ class FourierEmbeddings(eqx.Module):
+     """Defines a module for applying Fourier embeddings to timesteps.
+
+     This module differs from the other positional embedding modules because it
+     expects a continuous time input, rather than a discrete time input.
+
+     Parameters:
+         dim: The number of embedding dimensions. This value is used to determine
+             how many different frequencies to use, and a higher value means
+             higher frequencies.
+         max_period: The maximum period for the embeddings. This should roughly
+             be in line with the maximum number of timesteps; the default value
+             of 10,000 is commonly used in NLP applications, and is derived from
+             operating on sequence lengths of 100 to 1000 tokens.
+     """
+
+     dim: int
+     max_period: int
+
+     def __init__(self, dim: int, max_period: int = 10000) -> None:
+         super().__init__()
+
+         self.dim = dim
+         self.max_period = max_period
+
+     def __call__(self, t: Array) -> Array:
+         return fourier_embeddings(t, self.dim, self.max_period)
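
For orientation, here is a minimal usage sketch of the embedding API added above. It is not part of the diff; the import path follows the file location (xax/nn/embeddings.py), and the shapes are illustrative assumptions.

    import jax.numpy as jnp
    import jax.random as jrandom

    from xax.nn.embeddings import get_positional_embeddings, rotary_embeddings

    key = jrandom.PRNGKey(0)
    x_tc = jnp.zeros((16, 32))  # (tsz, embed_dim); assumed shape for illustration

    # Learned table lookup: returns the embedding slice, not x + embeddings.
    learned = get_positional_embeddings("learned", max_tsz=128, embed_dim=32, key=key)
    emb_tc = learned(x_tc)

    # Sinusoidal embeddings computed on the fly and added to the input.
    sinusoidal = get_positional_embeddings("sinusoidal", learnable=False)
    y_tc = sinusoidal(x_tc)

    # Functional rotary embeddings (embed_dim must be divisible by 4).
    z_tc = rotary_embeddings(x_tc)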
xax/nn/functions.py ADDED
@@ -0,0 +1,77 @@
+ # mypy: disable-error-code="override"
+ """Defines helper JAX functions."""
+
+ import random
+ from dataclasses import is_dataclass
+ from typing import Any, Callable, Iterable, Mapping, ParamSpec, Sequence, TypeVar
+
+ import jax.numpy as jnp
+ import numpy as np
+ from jaxtyping import Array
+
+ from xax.core.conf import load_user_config
+
+ T = TypeVar("T")
+ P = ParamSpec("P")
+
+
+ def recursive_apply(item: Any, func: Callable[[Array], Array], include_numpy: bool = False) -> Any:  # noqa: ANN401
+     """Applies a function recursively to tensors in an item.
+
+     Args:
+         item: The item to apply the function to.
+         func: The function to apply (to each tensor).
+         include_numpy: If set, also apply the function to numpy arrays,
+             after converting them to JAX arrays.
+
+     Returns:
+         The same item, with the function applied.
+     """
+     if isinstance(item, (str, int, float)):
+         return item
+     if include_numpy and isinstance(item, np.ndarray):
+         return func(jnp.array(item))
+     if isinstance(item, Array):
+         return func(item)
+     if is_dataclass(item):
+         return item.__class__(**{k: recursive_apply(v, func, include_numpy) for k, v in item.__dict__.items()})
+     if isinstance(item, Mapping):
+         return {k: recursive_apply(v, func, include_numpy) for k, v in item.items()}
+     if isinstance(item, Sequence):
+         return [recursive_apply(i, func, include_numpy) for i in item]
+     return item
+
+
+ def recursive_chunk(item: Any, num_chunks: int, dim: int = 0) -> Iterable[Any]:  # noqa: ANN401
+     """Recursively splits tensors into ``num_chunks`` chunks.
+
+     Args:
+         item: The item to recursively chunk.
+         num_chunks: The number of splits to make.
+         dim: The dimension to split along.
+
+     Yields:
+         ``num_chunks`` chunks of items.
+     """
+     if isinstance(item, (str, int, float)):
+         yield from (item for _ in range(num_chunks))
+     elif isinstance(item, np.ndarray):
+         yield from np.array_split(item, num_chunks, axis=dim)
+     elif is_dataclass(item):
+         yield from (
+             item.__class__(**{k: i for k, i in zip(item.__dict__, ii)})
+             for ii in zip(*(recursive_chunk(v, num_chunks, dim) for v in item.__dict__.values()))
+         )
+     elif isinstance(item, Mapping):
+         yield from (dict(zip(item, ii)) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item.values())))
+     elif isinstance(item, Sequence):
+         yield from (list(ii) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item)))
+     else:
+         yield from (item for _ in range(num_chunks))
+
+
+ def set_random_seed(seed: int | None = None, offset: int = 0) -> None:
+     if seed is None:
+         seed = load_user_config().experiment.default_random_seed
+     seed += offset
+     random.seed(seed)
+     np.random.seed(seed)
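
A short sketch of how these helpers compose (illustrative only; the ``Batch`` dataclass is hypothetical, not part of the package):

    from dataclasses import dataclass

    import jax.numpy as jnp
    import numpy as np

    from xax.nn.functions import recursive_apply, recursive_chunk

    @dataclass
    class Batch:  # hypothetical container for illustration
        inputs: np.ndarray
        labels: np.ndarray

    batch = Batch(inputs=np.ones((8, 4)), labels=np.zeros((8,)))

    # Convert every numpy array in the structure to a bfloat16 JAX array.
    jax_batch = recursive_apply(batch, lambda a: a.astype(jnp.bfloat16), include_numpy=True)

    # Split the batch into two chunks along the leading dimension; numpy
    # arrays are split with np.array_split, unrecognized leaves are repeated.
    for chunk in recursive_chunk(batch, num_chunks=2, dim=0):
        print(chunk.inputs.shape)  # (4, 4)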
xax/nn/parallel.py ADDED
@@ -0,0 +1,211 @@
+ """Defines functions for parallelism."""
+
+ import os
+
+ _RANK: int | None = None
+ _LOCAL_RANK: int | None = None
+ _WORLD_SIZE: int | None = None
+ _LOCAL_WORLD_SIZE: int | None = None
+ _MASTER_ADDR: str | None = None
+ _MASTER_PORT: int | None = None
+ _INIT_METHOD: str | None = None
+
+
+ def set_rank(rank: int) -> None:
+     global _RANK
+
+     if rank != _RANK:
+         _RANK = rank
+         os.environ["RANK"] = str(rank)
+     else:
+         raise ValueError(f"Rank {rank} is already set")
+
+
+ def get_rank_optional() -> int | None:
+     return _RANK
+
+
+ def get_rank() -> int:
+     return 0 if _RANK is None else _RANK
+
+
+ def clear_rank() -> None:
+     global _RANK
+
+     _RANK = None
+     os.environ.pop("RANK", None)
+
+
+ def set_local_rank(rank: int) -> None:
+     global _LOCAL_RANK
+
+     if rank != _LOCAL_RANK:
+         _LOCAL_RANK = rank
+         os.environ["LOCAL_RANK"] = str(rank)
+     else:
+         raise ValueError(f"Local rank {rank} is already set")
+
+
+ def get_local_rank_optional() -> int | None:
+     return _LOCAL_RANK
+
+
+ def get_local_rank() -> int:
+     return 0 if _LOCAL_RANK is None else _LOCAL_RANK
+
+
+ def clear_local_rank() -> None:
+     global _LOCAL_RANK
+
+     _LOCAL_RANK = None
+     os.environ.pop("LOCAL_RANK", None)
+
+
+ def set_world_size(world_size: int) -> None:
+     global _WORLD_SIZE
+
+     if world_size != _WORLD_SIZE:
+         _WORLD_SIZE = world_size
+         os.environ["WORLD_SIZE"] = str(world_size)
+     else:
+         raise ValueError(f"World size {world_size} is already set")
+
+
+ def get_world_size_optional() -> int | None:
+     return _WORLD_SIZE
+
+
+ def get_world_size() -> int:
+     return 1 if _WORLD_SIZE is None else _WORLD_SIZE
+
+
+ def clear_world_size() -> None:
+     global _WORLD_SIZE
+
+     _WORLD_SIZE = None
+     os.environ.pop("WORLD_SIZE", None)
+
+
+ def set_local_world_size(local_world_size: int) -> None:
+     global _LOCAL_WORLD_SIZE
+
+     if local_world_size != _LOCAL_WORLD_SIZE:
+         _LOCAL_WORLD_SIZE = local_world_size
+         os.environ["LOCAL_WORLD_SIZE"] = str(local_world_size)
+     else:
+         raise ValueError(f"Local world size {local_world_size} is already set")
+
+
+ def get_local_world_size_optional() -> int | None:
+     return _LOCAL_WORLD_SIZE
+
+
+ def get_local_world_size() -> int:
+     return 1 if _LOCAL_WORLD_SIZE is None else _LOCAL_WORLD_SIZE
+
+
+ def clear_local_world_size() -> None:
+     global _LOCAL_WORLD_SIZE
+
+     _LOCAL_WORLD_SIZE = None
+     os.environ.pop("LOCAL_WORLD_SIZE", None)
+
+
+ def set_master_addr(master_addr: str) -> None:
+     global _MASTER_ADDR
+
+     if master_addr != _MASTER_ADDR:
+         os.environ["MASTER_ADDR"] = _MASTER_ADDR = master_addr
+     else:
+         raise ValueError(f"Master address {master_addr} is already set")
+
+
+ def get_master_addr() -> str:
+     assert _MASTER_ADDR is not None, "Master address is not yet set"
+     return _MASTER_ADDR
+
+
+ def clear_master_addr() -> None:
+     global _MASTER_ADDR
+
+     _MASTER_ADDR = None
+     os.environ.pop("MASTER_ADDR", None)
+
+
+ def set_master_port(port: int) -> None:
+     global _MASTER_PORT
+
+     if port != _MASTER_PORT:
+         _MASTER_PORT = port
+         os.environ["MASTER_PORT"] = str(port)
+     else:
+         raise ValueError(f"Master port {port} is already set")
+
+
+ def get_master_port() -> int:
+     assert _MASTER_PORT is not None, "Master port is not yet set"
+     return _MASTER_PORT
+
+
+ def clear_master_port() -> None:
+     global _MASTER_PORT
+
+     _MASTER_PORT = None
+     os.environ.pop("MASTER_PORT", None)
+
+
+ def is_master() -> bool:
+     return get_rank() == 0
+
+
+ def is_distributed() -> bool:
+     return _INIT_METHOD is not None
+
+
+ def set_init_method(init_method: str) -> None:
+     global _INIT_METHOD
+
+     if init_method != _INIT_METHOD:
+         os.environ["INIT_METHOD"] = _INIT_METHOD = init_method
+     else:
+         raise ValueError(f"Init method {init_method} is already set")
+
+
+ def get_init_method() -> str:
+     assert _INIT_METHOD is not None, "Init method is not yet set"
+     return _INIT_METHOD
+
+
+ def clear_init_method() -> None:
+     global _INIT_METHOD
+
+     _INIT_METHOD = None
+     os.environ.pop("INIT_METHOD", None)
+
+
+ def set_dist(
+     rank: int,
+     local_rank: int,
+     world_size: int,
+     local_world_size: int,
+     master_addr: str,
+     master_port: int,
+     init_method: str,
+ ) -> None:
+     set_rank(rank)
+     set_local_rank(local_rank)
+     set_world_size(world_size)
+     set_local_world_size(local_world_size)
+     set_master_addr(master_addr)
+     set_master_port(master_port)
+     set_init_method(init_method)
+
+
+ def clear_dist() -> None:
+     clear_rank()
+     clear_local_rank()
+     clear_world_size()
+     clear_local_world_size()
+     clear_master_addr()
+     clear_master_port()
+     clear_init_method()
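
A minimal sketch of populating and querying this process-group state (the values are placeholder assumptions for a single-node, two-process layout):

    from xax.nn.parallel import clear_dist, get_world_size, is_master, set_dist

    set_dist(
        rank=0,
        local_rank=0,
        world_size=2,
        local_world_size=2,
        master_addr="127.0.0.1",
        master_port=29500,
        init_method="tcp://127.0.0.1:29500",
    )

    assert is_master()  # rank 0 is treated as the master process
    assert get_world_size() == 2  # also mirrored into os.environ["WORLD_SIZE"]

    clear_dist()  # resets module state and removes the environment variables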
xax/requirements-dev.txt ADDED
@@ -0,0 +1,15 @@
+ # requirements-dev.txt
+
+ # Linting.
+ black
+ darglint
+ mypy
+ ruff
+
+ # Tests.
+ pytest
+
+ # Types.
+ types-pillow
+ types-psutil
+ types-requests
xax/requirements.txt ADDED
@@ -0,0 +1,23 @@
+ # requirements.txt
+
+ # Core ML/JAX dependencies
+ jax
+ jaxtyping
+ equinox
+ optax
+ dpshdl
+
+ # Data processing and serialization
+ cloudpickle
+ pillow
+
+ # Configuration and project management
+ omegaconf
+ gitpython
+
+ # Monitoring and logging
+ tensorboard
+ psutil
+
+ # Networking
+ requests
xax/task/__init__.py ADDED
File without changes