xax 0.2.22__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
      python -m scripts.update_api --inplace
  """

- __version__ = "0.2.22"
+ __version__ = "0.3.0"

  # This list shouldn't be modified by hand; instead, run the update script.
  __all__ = [
@@ -23,6 +23,10 @@ __all__ = [
      "get_run_dir",
      "load_user_config",
      "State",
+     "CrossAttentionBlock",
+     "SelfAttentionBlock",
+     "Transformer",
+     "TransformerBlock",
      "FourierEmbeddings",
      "IdentityPositionalEmbeddings",
      "LearnedPositionalEmbeddings",
@@ -200,6 +204,10 @@ NAME_MAP: dict[str, str] = {
      "get_run_dir": "core.conf",
      "load_user_config": "core.conf",
      "State": "core.state",
+     "CrossAttentionBlock": "nn.attention",
+     "SelfAttentionBlock": "nn.attention",
+     "Transformer": "nn.attention",
+     "TransformerBlock": "nn.attention",
      "FourierEmbeddings": "nn.embeddings",
      "IdentityPositionalEmbeddings": "nn.embeddings",
      "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -370,6 +378,7 @@ if IMPORT_ALL or TYPE_CHECKING:
          load_user_config,
      )
      from xax.core.state import Phase, State
+     from xax.nn.attention import CrossAttentionBlock, SelfAttentionBlock, Transformer, TransformerBlock
      from xax.nn.embeddings import (
          EmbeddingKind,
          FourierEmbeddings,
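The new attention classes are re-exported at the package root via the __all__ and NAME_MAP entries above, so they can be imported either from xax directly or from xax.nn.attention. A minimal, illustrative sketch (hyperparameters are arbitrary and not taken from the diff):

import jax
import xax

# Equivalent to `from xax.nn.attention import SelfAttentionBlock`.
block = xax.SelfAttentionBlock(embed_dim=64, num_heads=4, key=jax.random.PRNGKey(0), causal=True)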
xax/core/state.py CHANGED
@@ -27,11 +27,8 @@ def _int_to_phase(i: int) -> Phase:
  class StateDict(TypedDict, total=False):
      num_steps: NotRequired[int | Array]
      num_samples: NotRequired[int | Array]
-     num_valid_steps: NotRequired[int | Array]
-     num_valid_samples: NotRequired[int | Array]
      start_time_s: NotRequired[float | Array]
      elapsed_time_s: NotRequired[float | Array]
-     valid_elapsed_time_s: NotRequired[float | Array]
      phase: NotRequired[Phase]
      _phase: NotRequired[int | Array]

@@ -47,38 +44,26 @@ class State:
          return self._int32_arr[0]

      @property
-     def num_valid_steps(self) -> Array:
-         return self._int32_arr[1]
+     def phase(self) -> Phase:
+         return _int_to_phase(self._int32_arr[1].item())

      @property
      def num_samples(self) -> Array:
          return self._float32_arr[0]

-     @property
-     def num_valid_samples(self) -> Array:
-         return self._float32_arr[1]
-
      @property
      def start_time_s(self) -> Array:
-         return self._float32_arr[2]
+         return self._float32_arr[1]

      @property
      def elapsed_time_s(self) -> Array:
-         return self._float32_arr[3]
-
-     @property
-     def valid_elapsed_time_s(self) -> Array:
-         return self._float32_arr[4]
-
-     @property
-     def phase(self) -> Phase:
-         return _int_to_phase(self._int32_arr[2].item())
+         return self._float32_arr[2]

      @classmethod
      def init_state(cls) -> "State":
          return cls(
-             _int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
-             _float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0, 0.0], dtype=jnp.float32),
+             _int32_arr=jnp.array([0, 0], dtype=jnp.int32),
+             _float32_arr=jnp.array([0.0, time.time(), 0.0], dtype=jnp.float32),
          )

      @property
@@ -91,25 +76,19 @@ class State:

          if "num_steps" in kwargs:
              int32_arr = int32_arr.at[0].set(kwargs["num_steps"])
-         if "num_valid_steps" in kwargs:
-             int32_arr = int32_arr.at[1].set(kwargs["num_valid_steps"])

          if "phase" in kwargs:
-             int32_arr = int32_arr.at[2].set(_phase_to_int(kwargs["phase"]))
+             int32_arr = int32_arr.at[1].set(_phase_to_int(kwargs["phase"]))
          if "_phase" in kwargs:
-             int32_arr = int32_arr.at[2].set(kwargs["_phase"])
+             int32_arr = int32_arr.at[1].set(kwargs["_phase"])

          if "num_samples" in kwargs:
              float32_arr = float32_arr.at[0].set(kwargs["num_samples"])
-         if "num_valid_samples" in kwargs:
-             float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])

          if "start_time_s" in kwargs:
-             float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
+             float32_arr = float32_arr.at[1].set(kwargs["start_time_s"])
          if "elapsed_time_s" in kwargs:
-             float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
-         if "valid_elapsed_time_s" in kwargs:
-             float32_arr = float32_arr.at[4].set(kwargs["valid_elapsed_time_s"])
+             float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])

          return State(
              _int32_arr=int32_arr,
@@ -119,12 +98,9 @@ class State:
      def to_dict(self) -> dict[str, int | float | str]:
          return {
              "num_steps": int(self.num_steps.item()),
-             "num_valid_steps": int(self.num_valid_steps.item()),
              "num_samples": int(self.num_samples.item()),
-             "num_valid_samples": int(self.num_valid_samples.item()),
              "start_time_s": float(self.start_time_s.item()),
              "elapsed_time_s": float(self.elapsed_time_s.item()),
-             "valid_elapsed_time_s": float(self.valid_elapsed_time_s.item()),
              "phase": str(self.phase),
          }

@@ -136,7 +112,6 @@ class State:
          int32_arr = jnp.array(
              [
                  d.get("num_steps", 0),
-                 d.get("num_valid_steps", 0),
                  d.get("_phase", 0),
              ],
              dtype=jnp.int32,
@@ -145,10 +120,8 @@ class State:
          float32_arr = jnp.array(
              [
                  d.get("num_samples", 0),
-                 d.get("num_valid_samples", 0),
                  d.get("start_time_s", time.time()),
                  d.get("elapsed_time_s", 0.0),
-                 d.get("valid_elapsed_time_s", 0.0),
              ],
              dtype=jnp.float32,
          )
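With the validation-specific counters removed, the packed state arrays shrink: _int32_arr now holds [num_steps, _phase] and _float32_arr holds [num_samples, start_time_s, elapsed_time_s]. A minimal sketch of how the new layout behaves, based only on the properties and replace() handling shown above (the "valid" phase name is assumed from the rest of the diff):

from xax.core.state import State

state = State.init_state()
state = state.replace(num_steps=10, num_samples=320.0, phase="valid")
assert int(state.num_steps.item()) == 10         # stored in _int32_arr[0]
assert str(state.phase) == "valid"               # decoded from _int32_arr[1]
assert float(state.num_samples.item()) == 320.0  # stored in _float32_arr[0]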
xax/nn/attention.py ADDED
@@ -0,0 +1,738 @@
+ """Attention mechanisms for transformer models."""
+
+ from typing import Literal, cast, overload
+
+ import chex
+ import equinox as eqx
+ import jax
+ import jax.numpy as jnp
+ from jaxtyping import Array, PRNGKeyArray
+
+
+ class SelfAttentionBlock(eqx.Module):
+     """Self-attention block using jax.nn.dot_product_attention."""
+
+     q_proj: eqx.nn.Linear
+     k_proj: eqx.nn.Linear
+     v_proj: eqx.nn.Linear
+     output_proj: eqx.nn.Linear
+     num_heads: int = eqx.static_field()
+     head_dim: int = eqx.static_field()
+     causal: bool = eqx.static_field()
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         *,
+         key: PRNGKeyArray,
+         causal: bool = False,
+     ) -> None:
+         keys = jax.random.split(key, 4)
+
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+
+         self.q_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[0])
+         self.k_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[1])
+         self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
+         self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
+
+         self.causal = causal
+
+     def _reshape_for_multihead(self, x: Array) -> Array:
+         """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
+         seq_len, _ = x.shape
+         return x.reshape(seq_len, self.num_heads, self.head_dim)
+
+     def _combine_heads(self, x: Array) -> Array:
+         """Reshape from (seq_len, num_heads, head_dim) to (seq_len, embed_dim)."""
+         seq_len, _, _ = x.shape
+         return x.reshape(seq_len, -1)
+
+     def __call__(
+         self,
+         x: Array,
+         *,
+         key: PRNGKeyArray | None = None,
+         mask: Array | None = None,
+         cache: dict[str, Array] | None = None,
+         update_cache: bool = False,
+     ) -> Array | tuple[Array, dict[str, Array]]:
+         """Apply self-attention to the input.
+
+         Args:
+             x: Input tensor of shape (seq_len, embed_dim)
+             key: PRNGKey for dropout randomness
+             mask: Optional mask tensor of shape (seq_len, seq_len) or broadcastable
+             cache: Optional dictionary containing cached key and value tensors
+             update_cache: Whether to update the cache and return it
+
+         Returns:
+             If update_cache is False: Output tensor of shape (seq_len, embed_dim)
+             If update_cache is True: Tuple of (output tensor, updated cache)
+         """
+         chex.assert_rank(x, 2)
+
+         # Project inputs to queries, keys, and values
+         q = jax.vmap(self.q_proj)(x)
+
+         # Use cached key/value if provided and not updating cache
+         if cache is not None and not update_cache:
+             k = cache["k"]
+             v = cache["v"]
+         else:
+             k = jax.vmap(self.k_proj)(x)
+             v = jax.vmap(self.v_proj)(x)
+
+         # Update cache if needed
+         if update_cache:
+             if cache is None:
+                 cache = {}
+             cache = {"k": k, "v": v}
+
+         # Reshape to multihead format
+         q = self._reshape_for_multihead(q)
+         k = self._reshape_for_multihead(k)
+         v = self._reshape_for_multihead(v)
+
+         # Apply dot product attention.
+         # Note that Apple Silicon struggles with this:
+         # https://github.com/jax-ml/jax/issues/20114
+         attn_output = jax.nn.dot_product_attention(
+             q,
+             k,
+             v,
+             mask=mask,
+             is_causal=self.causal and mask is None,
+         )
+
+         # Combine heads
+         attn_output = self._combine_heads(attn_output)
+
+         # Final projection
+         output = jax.vmap(self.output_proj)(attn_output)
+
+         if update_cache:
+             return output, cast(dict[str, Array], cache)
+         return output
+
+
+ class CrossAttentionBlock(eqx.Module):
+     """Cross-attention block using jax.nn.dot_product_attention."""
+
+     q_proj: eqx.nn.Linear
+     k_proj: eqx.nn.Linear
+     v_proj: eqx.nn.Linear
+     output_proj: eqx.nn.Linear
+     num_heads: int = eqx.static_field()
+     head_dim: int = eqx.static_field()
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         *,
+         key: PRNGKeyArray,
+     ) -> None:
+         keys = jax.random.split(key, 4)
+
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+
+         self.q_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[0])
+         self.k_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[1])
+         self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
+         self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
+
+     def _reshape_for_multihead(self, x: Array) -> Array:
+         """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
+         seq_len, _ = x.shape
+         return x.reshape(seq_len, self.num_heads, self.head_dim)
+
+     def _combine_heads(self, x: Array) -> Array:
+         """Reshape from (seq_len, num_heads, head_dim) to (seq_len, embed_dim)."""
+         seq_len, _, _ = x.shape
+         return x.reshape(seq_len, -1)
+
+     def __call__(
+         self,
+         q_input: Array,
+         kv_input: Array,
+         *,
+         key: PRNGKeyArray | None = None,
+         mask: Array | None = None,
+         cache: dict[str, Array] | None = None,
+         update_cache: bool = False,
+     ) -> Array | tuple[Array, dict[str, Array]]:
+         """Apply cross-attention.
+
+         Args:
+             q_input: Query input tensor of shape (q_seq_len, embed_dim)
+             kv_input: Key/value input tensor of shape (kv_seq_len, embed_dim)
+             key: PRNGKey for dropout randomness
+             mask: Optional mask tensor
+             cache: Optional dictionary containing cached key and value tensors
+             update_cache: Whether to update the cache and return it
+
+         Returns:
+             If update_cache is False: Output tensor of shape (q_seq_len, embed_dim)
+             If update_cache is True: Tuple of (output tensor, updated cache)
+         """
+         chex.assert_rank(q_input, 2)
+         chex.assert_rank(kv_input, 2)
+
+         # Project inputs to queries, keys, and values
+         q = jax.vmap(self.q_proj)(q_input)
+
+         # Use cached key/value if provided and not updating cache
+         if cache is not None and not update_cache:
+             k = cache["k"]
+             v = cache["v"]
+         else:
+             k = jax.vmap(self.k_proj)(kv_input)
+             v = jax.vmap(self.v_proj)(kv_input)
+
+         # Update cache if needed
+         if update_cache:
+             if cache is None:
+                 cache = {}
+             cache = {"k": k, "v": v}
+
+         # Reshape to multihead format
+         q = self._reshape_for_multihead(q)
+         k = self._reshape_for_multihead(k)
+         v = self._reshape_for_multihead(v)
+
+         # Apply dot product attention
+         attn_output = jax.nn.dot_product_attention(
+             q,
+             k,
+             v,
+             mask=mask,
+             is_causal=False,
+         )
+
+         # Combine heads
+         attn_output = self._combine_heads(attn_output)
+
+         # Final projection
+         output = jax.vmap(self.output_proj)(attn_output)
+
+         if update_cache:
+             return output, cast(dict[str, Array], cache)
+         return output
+
+
+ class TransformerBlock(eqx.Module):
+     self_attn: SelfAttentionBlock
+     cross_attn: CrossAttentionBlock | None
+     feed_forward: eqx.nn.MLP
+     layer_norm1: eqx.nn.LayerNorm
+     layer_norm2: eqx.nn.LayerNorm
+     layer_norm3: eqx.nn.LayerNorm | None
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         ff_dim: int,
+         *,
+         key: PRNGKeyArray,
+         causal: bool = False,
+         cross_attention: bool = False,
+     ) -> None:
+         keys = jax.random.split(key, 4)
+
+         self.self_attn = SelfAttentionBlock(
+             embed_dim=embed_dim,
+             num_heads=num_heads,
+             key=keys[0],
+             causal=causal,
+         )
+
+         if cross_attention:
+             self.cross_attn = CrossAttentionBlock(
+                 embed_dim=embed_dim,
+                 num_heads=num_heads,
+                 key=keys[1],
+             )
+             self.layer_norm3 = eqx.nn.LayerNorm(embed_dim)
+         else:
+             self.cross_attn = None
+             self.layer_norm3 = None
+
+         self.layer_norm1 = eqx.nn.LayerNorm(embed_dim)
+         self.layer_norm2 = eqx.nn.LayerNorm(embed_dim)
+
+         self.feed_forward = eqx.nn.MLP(
+             in_size=embed_dim,
+             out_size=embed_dim,
+             width_size=ff_dim,
+             depth=1,
+             activation=jax.nn.gelu,
+             key=keys[2],
+         )
+
+     @overload
+     def __call__(
+         self,
+         x: Array,
+         *,
+         context: Array | None = None,
+         self_mask: Array | None = None,
+         cross_mask: Array | None = None,
+         key: PRNGKeyArray | None = None,
+         cache: dict[str, dict[str, Array]] | None = None,
+         update_cache: Literal[True],
+     ) -> tuple[Array, dict[str, dict[str, Array]]]: ...
+
+     @overload
+     def __call__(
+         self,
+         x: Array,
+         *,
+         context: Array | None = None,
+         self_mask: Array | None = None,
+         cross_mask: Array | None = None,
+         key: PRNGKeyArray | None = None,
+         cache: dict[str, dict[str, Array]] | None = None,
+         update_cache: Literal[False] = False,
+     ) -> Array: ...
+
+     def __call__(
+         self,
+         x: Array,
+         *,
+         context: Array | None = None,
+         self_mask: Array | None = None,
+         cross_mask: Array | None = None,
+         key: PRNGKeyArray | None = None,
+         cache: dict[str, dict[str, Array]] | None = None,
+         update_cache: bool = False,
+     ) -> Array | tuple[Array, dict[str, dict[str, Array]]]:
+         """Apply transformer block.
+
+         Args:
+             x: Input tensor
+             context: Optional context for cross-attention
+             self_mask: Mask for self-attention
+             cross_mask: Mask for cross-attention
+             key: Optional PRNG key for dropout
+             cache: Optional dictionary containing cached key and value tensors
+             update_cache: Whether to update the cache and return it
+
+         Returns:
+             If update_cache is False: Output tensor
+             If update_cache is True: Tuple of (output tensor, updated cache)
+         """
+         chex.assert_rank(x, 2)
+         if key is not None:
+             key1, key2 = jax.random.split(key)
+         else:
+             key1 = key2 = None
+
+         # Initialize cache if needed
+         updated_cache = {}
+         if cache is None:
+             cache = {}
+
+         # Self-attention block with pre-norm
+         norm_x = jax.vmap(self.layer_norm1)(x)
+
+         self_attn_cache = cache.get("self_attn")
+         if update_cache:
+             attn_output, self_attn_cache = self.self_attn(
+                 norm_x, key=key1, mask=self_mask, cache=self_attn_cache, update_cache=True
+             )
+             updated_cache["self_attn"] = self_attn_cache
+         else:
+             attn_output = self.self_attn(norm_x, key=key1, mask=self_mask, cache=self_attn_cache)
+
+         x = x + attn_output
+
+         # Cross-attention block (if enabled) with pre-norm
+         if self.cross_attn is not None and context is not None:
+             assert self.layer_norm3 is not None
+
+             norm_x = jax.vmap(self.layer_norm3)(x)
+             cross_attn_cache = cache.get("cross_attn")
+
+             if update_cache:
+                 cross_attn_output, cross_attn_cache = self.cross_attn(
+                     norm_x, context, key=key2, mask=cross_mask, cache=cross_attn_cache, update_cache=True
+                 )
+                 updated_cache["cross_attn"] = cross_attn_cache
+             else:
+                 cross_attn_output = self.cross_attn(norm_x, context, key=key2, mask=cross_mask, cache=cross_attn_cache)
+
+             x = x + cross_attn_output
+
+         # Feed-forward block with pre-norm
+         norm_x = jax.vmap(self.layer_norm2)(x)
+         ff_output = jax.vmap(self.feed_forward)(norm_x)
+         x = x + ff_output
+
+         if update_cache:
+             return x, updated_cache
+         return x
+
+
+ class Transformer(eqx.Module):
+     token_embedding: eqx.nn.Embedding
+     position_embedding: eqx.nn.Embedding | None
+     layers: list[TransformerBlock]
+     output_layer: eqx.nn.Linear | None
+     layer_norm: eqx.nn.LayerNorm
+     max_seq_len: int = eqx.static_field()
+     embed_dim: int = eqx.static_field()
+
+     def __init__(
+         self,
+         vocab_size: int,
+         embed_dim: int,
+         num_heads: int,
+         ff_dim: int,
+         num_layers: int,
+         max_seq_len: int,
+         output_size: int | None = None,
+         *,
+         key: PRNGKeyArray,
+         causal: bool = False,
+         cross_attention: bool = False,
+         use_absolute_position: bool = True,
+     ) -> None:
+         keys = jax.random.split(key, num_layers + 3)
+
+         self.token_embedding = eqx.nn.Embedding(vocab_size, embed_dim, key=keys[0])
+
+         # Position embeddings can be disabled
+         if use_absolute_position:
+             self.position_embedding = eqx.nn.Embedding(max_seq_len, embed_dim, key=keys[1])
+         else:
+             self.position_embedding = None
+
+         self.layers = [
+             TransformerBlock(
+                 embed_dim=embed_dim,
+                 num_heads=num_heads,
+                 ff_dim=ff_dim,
+                 key=keys[i + 2],
+                 causal=causal,
+                 cross_attention=cross_attention,
+             )
+             for i in range(num_layers)
+         ]
+
+         self.layer_norm = eqx.nn.LayerNorm(embed_dim)
+
+         if output_size is not None:
+             self.output_layer = eqx.nn.Linear(embed_dim, output_size, key=keys[-1])
+         else:
+             self.output_layer = None
+
+         self.max_seq_len = max_seq_len
+         self.embed_dim = embed_dim
+
+     def _add_positional_embedding(self, x_embedded: Array, positions: Array | None = None) -> Array:
+         """Add positional embeddings to the token embeddings."""
+         if self.position_embedding is None:
+             return x_embedded
+
+         seq_len, _ = x_embedded.shape
+
+         if positions is None:
+             positions = jnp.arange(seq_len)
+         pos_embedded = jax.vmap(self.position_embedding)(positions)
+
+         return x_embedded + pos_embedded
+
+     def encode(
+         self,
+         x: Array,
+         mask: Array | None = None,
+         positions: Array | None = None,
+         key: PRNGKeyArray | None = None,
+         cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+         update_cache: bool = False,
+     ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+         """Encode the input sequence.
+
+         Args:
+             x: Input token indices of shape (seq_len)
+             mask: Optional attention mask
+             positions: Optional positions
+             key: Optional PRNG key for dropout
+             cache: Optional dictionary containing cached key and value tensors
+             update_cache: Whether to update the cache and return it
+
+         Returns:
+             If update_cache is False: Encoded representation
+             If update_cache is True: Tuple of (encoded representation, updated cache)
+         """
+         # Token embedding
+         x_embedded = jax.vmap(self.token_embedding)(x)
+
+         # Add positional embedding
+         x_embedded = self._add_positional_embedding(x_embedded, positions)
+
+         # Initialize layer caches
+         if cache is None and update_cache:
+             cache = {f"layer_{i}": {} for i in range(len(self.layers))}
+
+         # Updated cache will be built if needed
+         updated_cache = {}
+
+         # Apply transformer layers
+         keys: Array | list[None] = [None] * len(self.layers)
+         if key is not None:
+             keys = jax.random.split(key, len(self.layers))
+
+         for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
+             layer_cache = None if cache is None else cache.get(f"layer_{i}")
+
+             if update_cache:
+                 x_embedded, layer_updated_cache = layer.__call__(
+                     x_embedded,
+                     self_mask=mask,
+                     key=layer_key,
+                     cache=layer_cache,
+                     update_cache=True,
+                 )
+                 updated_cache[f"layer_{i}"] = layer_updated_cache
+             else:
+                 x_embedded = layer.__call__(
+                     x_embedded,
+                     self_mask=mask,
+                     key=layer_key,
+                     cache=layer_cache,
+                 )
+
+         # Apply final layer norm
+         output = jax.vmap(self.layer_norm)(x_embedded)
+
+         if update_cache:
+             return output, updated_cache
+         return output
+
+     def decode(
+         self,
+         x: Array,
+         context: Array,
+         self_mask: Array | None = None,
+         cross_mask: Array | None = None,
+         positions: Array | None = None,
+         key: PRNGKeyArray | None = None,
+         cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+         update_cache: bool = False,
+     ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+         """Decode with self-attention and cross-attention.
+
+         Args:
+             x: Input token indices
+             context: Context from encoder
+             self_mask: Optional self-attention mask
+             cross_mask: Optional cross-attention mask
+             positions: Optional positions
+             key: Optional PRNG key for dropout
+             cache: Optional dictionary containing cached key and value tensors
+             update_cache: Whether to update the cache and return it
+
+         Returns:
+             If update_cache is False: Decoded representation
+             If update_cache is True: Tuple of (decoded representation, updated cache)
+         """
+         # Token embedding
+         x_embedded = jax.vmap(lambda x_seq: jax.vmap(self.token_embedding)(x_seq))(x)
+
+         # Add positional embedding
+         x_embedded = self._add_positional_embedding(x_embedded, positions)
+
+         # Initialize layer caches
+         if cache is None and update_cache:
+             cache = {f"layer_{i}": {} for i in range(len(self.layers))}
+
+         # Updated cache will be built if needed
+         updated_cache = {}
+
+         # Apply transformer layers with cross-attention
+         keys: Array | list[None] = [None] * len(self.layers)
+         if key is not None:
+             keys = jax.random.split(key, len(self.layers))
+
+         for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
+             layer_cache = None if cache is None else cache.get(f"layer_{i}")
+
+             if update_cache:
+                 x_embedded, layer_updated_cache = layer.__call__(
+                     x_embedded,
+                     context=context,
+                     self_mask=self_mask,
+                     cross_mask=cross_mask,
+                     key=layer_key,
+                     cache=layer_cache,
+                     update_cache=True,
+                 )
+                 updated_cache[f"layer_{i}"] = layer_updated_cache
+             else:
+                 x_embedded = layer(
+                     x_embedded,
+                     context=context,
+                     self_mask=self_mask,
+                     cross_mask=cross_mask,
+                     key=layer_key,
+                     cache=layer_cache,
+                 )
+
+         # Apply final layer norm
+         output = jax.vmap(self.layer_norm)(x_embedded)
+
+         if update_cache:
+             return output, updated_cache
+         return output
+
+     @overload
+     def __call__(
+         self,
+         x: Array,
+         *,
+         mask: Array | None = None,
+         positions: Array | None = None,
+         key: PRNGKeyArray | None = None,
+         cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+         update_cache: Literal[True],
+     ) -> tuple[Array, dict[str, dict[str, dict[str, Array]]]]: ...
+
+     @overload
+     def __call__(
+         self,
+         x: Array,
+         *,
+         mask: Array | None = None,
+         positions: Array | None = None,
+         key: PRNGKeyArray | None = None,
+         cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+         update_cache: Literal[False] = False,
+     ) -> Array: ...
+
+     def __call__(
+         self,
+         x: Array,
+         *,
+         mask: Array | None = None,
+         positions: Array | None = None,
+         key: PRNGKeyArray | None = None,
+         cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+         update_cache: bool = False,
+     ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+         """Forward pass for encoder-only or decoder-only transformers.
+
+         Args:
+             x: Input token indices of shape (seq_len)
+             mask: Optional attention mask
+             positions: Optional positions
+             key: Optional PRNG key for dropout
+             cache: Optional dictionary containing cached key and value tensors
+             update_cache: Whether to update the cache and return it
+
+         Returns:
+             If update_cache is False: Output representation
+             If update_cache is True: Tuple of (output representation, updated cache)
+         """
+         chex.assert_rank(x, 1)
+
+         if update_cache:
+             output, updated_cache = self.encode(
+                 x, mask=mask, positions=positions, key=key, cache=cache, update_cache=True
+             )
+         else:
+             output = self.encode(x, mask=mask, positions=positions, key=key, cache=cache)
+
+         # Apply output layer if it exists
+         if self.output_layer is not None:
+             output = jax.vmap(self.output_layer)(output)
+
+         if update_cache:
+             return output, updated_cache
+         return output
+
+     def predict_sequence(self, x_seq: Array) -> Array:
+         return self(x=x_seq)
+
+     def generate_sequence(
+         self,
+         prompt_seq: Array,
+         max_len: int,
+         temperature: float = 1.0,
+         top_k: int | None = None,
+         key: PRNGKeyArray | None = None,
+     ) -> Array:
+         """Generate a sequence autoregressively with KV caching.
+
+         Args:
+             prompt_seq: Input token indices of shape (prompt_len,)
+             max_len: Maximum length of generated sequence
+             temperature: Sampling temperature
+             top_k: Optional top-k sampling parameter
+             key: PRNG key for sampling
+
+         Returns:
+             Generated sequence of shape (prompt_len + max_len,)
+         """
+         if key is None:
+             key = jax.random.PRNGKey(0)
+
+         prompt_len = prompt_seq.shape[0]
+         sequence = prompt_seq
+
+         # Create causal mask for generation
+         causal_mask = jnp.tril(jnp.ones((self.max_seq_len, self.max_seq_len), dtype=jnp.bool_))
+
+         # Initialize cache with the prompt
+         _, cache = self(x=prompt_seq, mask=causal_mask[:prompt_len, :prompt_len], update_cache=True)
+
+         # Define decode step function (for clarity)
+         def decode_step(seq: Array, pos: int, cur_cache: dict, rng: PRNGKeyArray) -> tuple[Array, dict, PRNGKeyArray]:
+             # Get the next position and last token
+             pos_tensor = jnp.array([pos])
+             last_token = seq[-1:]
+
+             # Get logits for next token
+             rng, subrng = jax.random.split(rng)
+             logits, new_cache = self(
+                 x=last_token,
+                 positions=pos_tensor,
+                 key=subrng,
+                 cache=cur_cache,
+                 update_cache=True,
+             )
+
+             # Extract final logits and apply temperature
+             logits = logits[-1] / temperature
+
+             # Apply top-k sampling if specified
+             if top_k is not None:
+                 top_logits, top_indices = jax.lax.top_k(logits, top_k)
+                 logits = jnp.full_like(logits, float("-inf"))
+                 logits = logits.at[top_indices].set(top_logits)
+
+             # Sample next token
+             rng, subrng = jax.random.split(rng)
+             next_token = jax.random.categorical(subrng, logits[None, ...])[0]
+
+             # Add token to sequence
+             new_seq = jnp.concatenate([seq, next_token[None]], axis=0)
+             return new_seq, new_cache, rng
+
+         # Generate tokens one by one
+         for _ in range(max_len):
+             # Break if max sequence length reached
+             if sequence.shape[0] >= self.max_seq_len:
+                 break
+
+             # Decode next token
+             sequence, cache, key = decode_step(seq=sequence, pos=sequence.shape[0] - 1, cur_cache=cache, rng=key)
+
+         return sequence
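A minimal usage sketch for the new module, using only the constructor and method signatures defined above; this is illustrative and not part of the package diff, and all hyperparameters are arbitrary:

import jax
import jax.numpy as jnp
from xax.nn.attention import Transformer

key = jax.random.PRNGKey(0)
# Decoder-only setup: causal self-attention, logits projected back to the vocabulary.
model = Transformer(
    vocab_size=100,
    embed_dim=32,
    num_heads=4,
    ff_dim=64,
    num_layers=2,
    max_seq_len=16,
    output_size=100,
    key=key,
    causal=True,
)

prompt = jnp.array([1, 2, 3])
logits = model(prompt)  # shape (3, 100)
tokens = model.generate_sequence(prompt, max_len=5, temperature=1.0, top_k=10, key=key)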
xax/task/logger.py CHANGED
@@ -526,7 +526,7 @@ class LoggerImpl(ABC):
          Returns:
              If the logger should log the current step.
          """
-         elapsed_time = state.elapsed_time_s.item() if state.phase == "train" else state.valid_elapsed_time_s.item()
+         elapsed_time = state.elapsed_time_s.item()
          return self.tickers[state.phase].tick(elapsed_time)

xax/task/mixins/train.py CHANGED
@@ -121,7 +121,7 @@ class ValidStepTimer:
          self.last_valid_step = state.num_steps.item()

      def __call__(self, state: State) -> bool:
-         if state.num_steps < self.valid_first_n_steps and state.num_valid_steps < self.valid_first_n_steps:
+         if state.num_steps < self.valid_first_n_steps:
              return True

          if self.last_valid_time is None or self.last_valid_step is None:
@@ -130,18 +130,15 @@ class ValidStepTimer:

          # Step-based validation.
          valid_every_n_steps = self.valid_every_n_steps
-         if valid_every_n_steps is not None and (
-             state.num_steps >= valid_every_n_steps + self.last_valid_step
-             or state.num_valid_steps >= valid_every_n_steps + self.last_valid_step
-         ):
+         if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
              self._reset(state)
              return True

          # Time-based validation.
          valid_every_n_seconds = self.valid_every_n_seconds
-         if valid_every_n_seconds is not None and (
-             state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
-             or state.valid_elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
+         if (
+             valid_every_n_seconds is not None
+             and state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
          ):
              self._reset(state)
              return True
@@ -149,10 +146,7 @@ class ValidStepTimer:
          # Time-based validation for first validation step.
          if self.first_valid_step_flag:
              valid_first_n_seconds = self.valid_first_n_seconds
-             if valid_first_n_seconds is not None and (
-                 state.elapsed_time_s.item() >= valid_first_n_seconds
-                 or state.valid_elapsed_time_s.item() >= valid_first_n_seconds
-             ):
+             if valid_first_n_seconds is not None and state.elapsed_time_s.item() >= valid_first_n_seconds:
                  self._reset(state)
                  self.first_valid_step_flag = False
                  return True
@@ -777,12 +771,12 @@ class TrainMixin(
          self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)

          state = state.replace(
-             num_valid_steps=state.num_valid_steps + 1,
-             num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
+             num_steps=state.num_steps + 1,
+             num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
          )

          state = state.replace(
-             valid_elapsed_time_s=state.valid_elapsed_time_s + timer.elapsed_time,
+             elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
          )

          with ContextTimer() as timer:
@@ -882,7 +876,7 @@ class TrainMixin(
          key, model_key = jax.random.split(key)
          models, optimizers, opt_states, state = self.load_initial_state(model_key, load_optimizer=True)
          logger.info("Model size: %s", f"{get_pytree_param_count(models):,}")
-         logger.info("Optimizer size: %s", f"{get_pytree_param_count(optimizers):,}")
+         logger.info("Optimizer size: %s", f"{get_pytree_param_count(opt_states):,}")

          state = self.on_training_start(state)

xax/utils/experiments.py CHANGED
@@ -111,8 +111,8 @@ class StateTimer:

      def step(self, state: State) -> None:
          cur_time = time.time()
-         num_steps = int((state.num_steps if state.phase == "train" else state.num_valid_steps).item())
-         num_samples = int((state.num_samples if state.phase == "train" else state.num_valid_samples).item())
+         num_steps = int(state.num_steps.item())
+         num_samples = int(state.num_samples.item())
          self.step_timer.step(num_steps, cur_time)
          self.sample_timer.step(num_samples, cur_time)
          self.iter_timer.step(cur_time)
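With the valid-specific fields gone, training and validation advance the same num_steps, num_samples, and elapsed_time_s counters, and both ValidStepTimer and StateTimer read those shared fields regardless of phase. A condensed sketch of the simplified step-based trigger from the train.py diff above (a standalone function written here for illustration only):

def should_validate(num_steps: int, last_valid_step: int, valid_every_n_steps: int | None) -> bool:
    # Mirrors the post-change condition in ValidStepTimer.__call__.
    return valid_every_n_steps is not None and num_steps >= valid_every_n_steps + last_valid_step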
xax-0.2.22.dist-info/METADATA → xax-0.3.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: xax
- Version: 0.2.22
+ Version: 0.3.0
  Summary: A library for fast Jax experimentation
  Home-page: https://github.com/kscalelabs/xax
  Author: Benjamin Bolte
xax-0.2.22.dist-info/RECORD → xax-0.3.0.dist-info/RECORD CHANGED
@@ -1,4 +1,4 @@
- xax/__init__.py,sha256=Wh6x1Nohprb7ZxS_Y1aHPSo2xD7rAFSbmz31HLRl5og,15293
+ xax/__init__.py,sha256=4JFBksXZsEjxfY7yHj4cyGf7vI6plavry36LU1Kq-oY,15652
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -6,8 +6,9 @@ xax/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/cli/edit_config.py,sha256=LQUIlOS6hvPZyVEaMme3FP-62M0BKQPYavCwVDWuBLw,2600
  xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
- xax/core/state.py,sha256=KsNMnM_RgsZ2Ntc2pp4Fi6zG4rZb_89-kqmyGxDvyRg,4974
+ xax/core/state.py,sha256=F9Tj3FfCw8zFKaDEoEGiThZE2ntYEtzNjnBX3pQ1g60,3826
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ xax/nn/attention.py,sha256=0essK90OO3x9FxnUqU0DhufwXKRMN41zMtRCki5iKzQ,24742
  xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
  xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
  xax/nn/geom.py,sha256=A7WPefMvgwUNReZC7_HX1GmvHPASyghbaXaKsuhwDrE,7382
@@ -17,7 +18,7 @@ xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
  xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/task/base.py,sha256=TYANmjNcce4_V5ZSYLnE91PXRn7Nn0nT7hN8plW_Au0,8117
- xax/task/logger.py,sha256=W_BpluYvQai1lh1dDCAj-2_mWUC1buhwJncHygDffjc,41125
+ xax/task/logger.py,sha256=Bmhl4mv08Aq49ZyX6BdjPIsPJK28e8s3mVFatM4IY2Q,41060
  xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
  xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
  xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -41,10 +42,10 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
- xax/task/mixins/train.py,sha256=eueQc6P15Gkc9_lU7sp7fIHt4qrqOmhc4Xt6pCYZPkw,33636
+ xax/task/mixins/train.py,sha256=TZatz5QwTfrNhQTiO2IqrmQY9P4Lay6FAD2VsQpWa54,33245
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
- xax/utils/experiments.py,sha256=bj8BftSHT3fFzfiJ0Co0WvqWo0rUS8kQnQYpVvH8FTM,29942
+ xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
  xax/utils/jax.py,sha256=6cP95-rcjkRt1fefkZWJQhJhH0uUYWJB3w4NP1-aDp0,10136
  xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
  xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
@@ -58,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
- xax-0.2.22.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
- xax-0.2.22.dist-info/METADATA,sha256=FtyVr4ve7FYrZCkDUWqneCAszYI-QSNs_ZTPrdbXUxg,1247
- xax-0.2.22.dist-info/WHEEL,sha256=QZxptf4Y1BKFRCEDxD4h2V0mBFQOVFLFEpvxHmIs52A,91
- xax-0.2.22.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
- xax-0.2.22.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
- xax-0.2.22.dist-info/RECORD,,
+ xax-0.3.0.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+ xax-0.3.0.dist-info/METADATA,sha256=HjHuF55MnVyLkWEyzmfftbaPlAVsd7qGbjrOgWioEw8,1246
+ xax-0.3.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+ xax-0.3.0.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+ xax-0.3.0.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+ xax-0.3.0.dist-info/RECORD,,
xax-0.2.22.dist-info/WHEEL → xax-0.3.0.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.6.0)
+ Generator: setuptools (80.7.1)
  Root-Is-Purelib: true
  Tag: py3-none-any