xax 0.3.10__py3-none-any.whl → 0.3.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.10"
15
+ __version__ = "0.3.12"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -138,6 +138,7 @@ __all__ = [
138
138
  "worker_chunk",
139
139
  "profile",
140
140
  "compute_nan_ratio",
141
+ "diff_pytree",
141
142
  "flatten_array",
142
143
  "flatten_pytree",
143
144
  "get_pytree_mapping",
@@ -330,6 +331,7 @@ NAME_MAP: dict[str, str] = {
330
331
  "worker_chunk": "utils.numpy",
331
332
  "profile": "utils.profile",
332
333
  "compute_nan_ratio": "utils.pytree",
334
+ "diff_pytree": "utils.pytree",
333
335
  "flatten_array": "utils.pytree",
334
336
  "flatten_pytree": "utils.pytree",
335
337
  "get_pytree_mapping": "utils.pytree",
@@ -413,7 +415,12 @@ if IMPORT_ALL or TYPE_CHECKING:
413
415
  TransformerCache,
414
416
  TransformerStack,
415
417
  )
416
- from xax.nn.distributions import Categorical, Distribution, MixtureOfGaussians, Normal
418
+ from xax.nn.distributions import (
419
+ Categorical,
420
+ Distribution,
421
+ MixtureOfGaussians,
422
+ Normal,
423
+ )
417
424
  from xax.nn.embeddings import (
418
425
  EmbeddingKind,
419
426
  FourierEmbeddings,
@@ -518,6 +525,7 @@ if IMPORT_ALL or TYPE_CHECKING:
518
525
  from xax.utils.profile import profile
519
526
  from xax.utils.pytree import (
520
527
  compute_nan_ratio,
528
+ diff_pytree,
521
529
  flatten_array,
522
530
  flatten_pytree,
523
531
  get_pytree_mapping,
xax/nn/distributions.py CHANGED
@@ -18,6 +18,9 @@ import jax
18
18
  import jax.numpy as jnp
19
19
  from jaxtyping import Array, PRNGKeyArray
20
20
 
21
+ STD_CLIP = 1e-6
22
+ LOGIT_CLIP = 6.0
23
+
21
24
 
22
25
  class Distribution(ABC):
23
26
  @abstractmethod
@@ -34,87 +37,91 @@ class Distribution(ABC):
34
37
 
35
38
 
36
39
  class Categorical(Distribution):
37
- def __init__(self, logits_n: Array) -> None:
38
- self.logits_n = logits_n
40
+ def __init__(self, logits_nc: Array, logit_clip: float = LOGIT_CLIP) -> None:
41
+ """Initialize a categorical distribution.
42
+
43
+ Args:
44
+ logits_nc: Array of shape (..., n_categories) containing logits
45
+ logit_clip: Clipping value for logits
46
+ """
47
+ self.logits_nc = jnp.clip(logits_nc, -logit_clip, logit_clip)
39
48
 
40
49
  @property
41
50
  def num_categories(self) -> int:
42
- return self.logits_n.shape[-1]
51
+ return self.logits_nc.shape[-1]
43
52
 
44
- def log_prob(self, x: Array) -> Array:
45
- """Compute log probability for specific categories.
46
-
47
- Args:
48
- x: Array of category indices
49
-
50
- Returns:
51
- Log probabilities for the given categories
52
- """
53
- log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
54
- # Use advanced indexing to get the log probabilities for the given categories
55
- return log_probs[x]
53
+ def log_prob(self, x_n: Array) -> Array:
54
+ log_probs_n = jax.nn.log_softmax(self.logits_nc, axis=-1)
55
+ return log_probs_n[x_n]
56
56
 
57
57
  def sample(self, key: PRNGKeyArray) -> Array:
58
- return jax.random.categorical(key, self.logits_n, axis=-1)
58
+ return jax.random.categorical(key, self.logits_nc, axis=-1)
59
59
 
60
60
  def mode(self) -> Array:
61
- return self.logits_n.argmax(axis=-1)
61
+ return self.logits_nc.argmax(axis=-1)
62
62
 
63
63
  def entropy(self) -> Array:
64
- """Compute entropy of the categorical distribution."""
65
- probs = jax.nn.softmax(self.logits_n, axis=-1)
66
- log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
64
+ probs = jax.nn.softmax(self.logits_nc, axis=-1)
65
+ log_probs = jax.nn.log_softmax(self.logits_nc, axis=-1)
67
66
  return -jnp.sum(probs * log_probs, axis=-1)
68
67
 
69
68
 
70
69
  class Normal(Distribution):
71
- def __init__(self, loc: Array, scale: Array) -> None:
72
- self.loc = loc
73
- self.scale = scale
70
+ def __init__(self, loc_n: Array, scale_n: Array, std_clip: float = STD_CLIP) -> None:
71
+ """Initialize a normal distribution.
72
+
73
+ Args:
74
+ loc_n: Mean of the distribution
75
+ scale_n: Standard deviation of the distribution
76
+ std_clip: Minimum standard deviation
77
+ """
78
+ self.loc_n = loc_n
79
+ self.scale_n = jnp.clip(scale_n, min=std_clip)
74
80
 
75
81
  def log_prob(self, x: Array) -> Array:
76
- return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.scale) - (x - self.loc) ** 2 / (2 * self.scale**2)
82
+ return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.scale_n) - (x - self.loc_n) ** 2 / (2 * self.scale_n**2)
77
83
 
78
84
  def sample(self, key: PRNGKeyArray) -> Array:
79
- return self.loc + self.scale * jax.random.normal(key, self.loc.shape)
85
+ return self.loc_n + self.scale_n * jax.random.normal(key, self.loc_n.shape)
80
86
 
81
87
  def mode(self) -> Array:
82
- return self.loc
88
+ return self.loc_n
83
89
 
84
90
  def entropy(self) -> Array:
85
- return jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.scale)
91
+ return jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.scale_n)
86
92
 
87
93
 
88
94
  class MixtureOfGaussians(Distribution):
89
- def __init__(self, means_nm: Array, stds_nm: Array, logits_nm: Array) -> None:
95
+ def __init__(
96
+ self,
97
+ means_nm: Array,
98
+ stds_nm: Array,
99
+ logits_nm: Array,
100
+ std_clip: float = STD_CLIP,
101
+ logit_clip: float = LOGIT_CLIP,
102
+ ) -> None:
90
103
  """Initialize a mixture of Gaussians.
91
104
 
92
105
  Args:
93
106
  means_nm: Array of shape (..., n_components) containing means
94
107
  stds_nm: Array of shape (..., n_components) containing standard deviations
95
108
  logits_nm: Array of shape (..., n_components) containing mixing logits
109
+ std_clip: Minimum standard deviation
110
+ logit_clip: Clipping value for logits
96
111
  """
97
112
  self.means_nm = means_nm
98
- self.stds_nm = stds_nm
99
- self.logits_nm = logits_nm
100
-
101
- def log_prob(self, x: Array) -> Array:
102
- """Compute log probability of the mixture.
113
+ self.stds_nm = jnp.clip(stds_nm, min=std_clip)
114
+ self.logits_nm = jnp.clip(logits_nm, -logit_clip, logit_clip)
103
115
 
104
- Args:
105
- x: Array of shape (...,) containing values to evaluate
106
-
107
- Returns:
108
- Log probabilities of shape (...,)
109
- """
116
+ def log_prob(self, x_n: Array) -> Array:
110
117
  # Expand x to match component dimensions
111
- x_expanded = x[..., None] # Shape: (..., 1)
118
+ x_n_expanded = x_n[..., None] # Shape: (..., 1)
112
119
 
113
120
  # Compute log probabilities for each component
114
121
  component_log_probs = (
115
122
  -0.5 * jnp.log(2 * jnp.pi)
116
123
  - jnp.log(self.stds_nm)
117
- - (x_expanded - self.means_nm) ** 2 / (2 * self.stds_nm**2)
124
+ - (x_n_expanded - self.means_nm) ** 2 / (2 * self.stds_nm**2)
118
125
  )
119
126
 
120
127
  # Compute mixing weights
@@ -123,16 +130,7 @@ class MixtureOfGaussians(Distribution):
123
130
  # Combine using log-sum-exp trick for numerical stability
124
131
  return jax.scipy.special.logsumexp(component_log_probs + mixing_logits, axis=-1)
125
132
 
126
- def sample(self, key: PRNGKeyArray) -> Array:
127
- """Sample from the mixture of Gaussians.
128
-
129
- Args:
130
- key: PRNG key
131
-
132
- Returns:
133
- Samples of shape (...,) where ... are the batch dimensions
134
- """
135
- # Sample component indices
133
+ def sample(self, key: PRNGKeyArray) -> Array: # Sample component indices
136
134
  component_key, sample_key = jax.random.split(key)
137
135
  component_indices = jax.random.categorical(component_key, self.logits_nm, axis=-1)
138
136
 
@@ -153,8 +151,8 @@ class MixtureOfGaussians(Distribution):
153
151
  noise = jax.random.normal(sample_key, selected_means.shape)
154
152
 
155
153
  # Reshape back to original batch shape
156
- samples = selected_means + selected_stds * noise
157
- return samples.reshape(batch_shape)
154
+ samples_n = selected_means + selected_stds * noise
155
+ return samples_n.reshape(batch_shape)
158
156
 
159
157
  def mode(self) -> Array:
160
158
  """Return the mode of the mixture (approximate - returns mean of highest weight component)."""
xax/task/mixins/train.py CHANGED
@@ -177,7 +177,6 @@ class TrainConfig(
177
177
  step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
178
178
  updates_per_step: int = field(1, help="Number of updates to perform per step")
179
179
  random_seed: int = field(1337, help="Random seed for the task")
180
- global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")
181
180
 
182
181
 
183
182
  Config = TypeVar("Config", bound=TrainConfig)
xax/utils/pytree.py CHANGED
@@ -1,12 +1,15 @@
1
1
  """Utils for accessing, modifying, and otherwise manipulating pytrees."""
2
2
 
3
+ from dataclasses import fields, is_dataclass
3
4
  from typing import Mapping, Sequence, TypeVar
4
5
 
5
6
  import chex
6
7
  import equinox as eqx
7
8
  import jax
8
9
  import jax.numpy as jnp
10
+ import numpy as np
9
11
  from jax import Array
12
+ from jax.core import get_aval
10
13
  from jaxtyping import PRNGKeyArray, PyTree
11
14
 
12
15
  T = TypeVar("T")
@@ -258,18 +261,79 @@ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
258
261
  def get_pytree_mapping(pytree: PyTree) -> dict[str, Array]:
259
262
  leaves: dict[str, Array] = {}
260
263
 
261
- def _get_str(thing: PyTree) -> str:
262
- if isinstance(thing, str):
263
- return thing
264
- if isinstance(thing, Sequence):
265
- return "/".join(_get_str(x) for x in thing)
266
- if isinstance(thing, Mapping):
267
- return "/".join(f"{_get_str(k)}:{_get_str(v)}" for k, v in thing.items())
268
- return str(thing)
269
-
270
264
  def _get_leaf(path: tuple, x: PyTree) -> None:
271
265
  if isinstance(x, jnp.ndarray):
272
- leaves[_get_str(path)] = x
266
+ leaves[jax.tree_util.keystr(path, simple=True, separator="/")] = x
273
267
 
274
268
  jax.tree.map_with_path(_get_leaf, pytree)
275
269
  return leaves
270
+
271
+
272
+ def diff_pytree(tree_a: PyTree, tree_b: PyTree, prefix: str = "") -> list[str]:
273
+ diffs = []
274
+
275
+ # Handles dataclasses.
276
+ if is_dataclass(tree_a) and is_dataclass(tree_b):
277
+ for field in fields(tree_a):
278
+ attr_a, attr_b = getattr(tree_a, field.name), getattr(tree_b, field.name)
279
+ diffs.extend(diff_pytree(attr_a, attr_b, prefix + f"{field.name}."))
280
+ return diffs
281
+
282
+ # Handle dict-like objects
283
+ elif isinstance(tree_a, Mapping) and isinstance(tree_b, Mapping):
284
+ if type(tree_a) is not type(tree_b):
285
+ diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
286
+ return diffs
287
+ keys_a, keys_b = set(tree_a.keys()), set(tree_b.keys())
288
+ for k in keys_a - keys_b:
289
+ diffs.append(f"{prefix}{k}: present in A only")
290
+ for k in keys_b - keys_a:
291
+ diffs.append(f"{prefix}{k}: present in B only")
292
+ for k in keys_a & keys_b:
293
+ diffs.extend(diff_pytree(tree_a[k], tree_b[k], prefix + f"{k}."))
294
+ return diffs
295
+
296
+ # Handle tuple/list
297
+ elif isinstance(tree_a, Sequence) and isinstance(tree_b, Sequence):
298
+ if type(tree_a) is not type(tree_b):
299
+ diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
300
+ return diffs
301
+ if len(tree_a) != len(tree_b):
302
+ diffs.append(f"{prefix}: different lengths {len(tree_a)} vs {len(tree_b)}")
303
+ for i, (a_i, b_i) in enumerate(zip(tree_a, tree_b, strict=True)):
304
+ diffs.extend(diff_pytree(a_i, b_i, prefix + f"[{i}]."))
305
+ return diffs
306
+
307
+ # Handles basic types.
308
+ elif isinstance(tree_a, (int, float, bool, str, type(None), np.number, np.bool, bytes)):
309
+ if tree_a != tree_b:
310
+ diffs.append(f"{prefix}: {tree_a!r} vs {tree_b!r}")
311
+ return diffs
312
+
313
+ # Handles Numpy arrays.
314
+ elif isinstance(tree_a, np.ndarray) and isinstance(tree_b, np.ndarray):
315
+ if tree_a.shape != tree_b.shape:
316
+ diffs.append(f"{prefix}: shape {tree_a.shape} vs {tree_b.shape}")
317
+ if tree_a.dtype != tree_b.dtype:
318
+ diffs.append(f"{prefix}: dtype {tree_a.dtype} vs {tree_b.dtype}")
319
+ return diffs
320
+
321
+ # Handle arrays (check shape/dtype)
322
+ elif isinstance(tree_a, jnp.ndarray) and isinstance(tree_b, jnp.ndarray):
323
+ if tree_a.shape != tree_b.shape:
324
+ diffs.append(f"{prefix}: shape {tree_a.shape} vs {tree_b.shape}")
325
+ if tree_a.dtype != tree_b.dtype:
326
+ diffs.append(f"{prefix}: dtype {tree_a.dtype} vs {tree_b.dtype}")
327
+ aval_a = get_aval(tree_a)
328
+ aval_b = get_aval(tree_b)
329
+ if aval_a != aval_b: # pyright: ignore[reportAttributeAccessIssue]
330
+ diffs.append(f"{prefix}: aval {aval_a} vs {aval_b}")
331
+ return diffs
332
+
333
+ # Handle mismatched types
334
+ elif type(tree_a) is not type(tree_b):
335
+ diffs.append(f"{prefix}: type {type(tree_a)} vs {type(tree_b)}")
336
+ return diffs
337
+
338
+ else:
339
+ raise ValueError(f"Unknown type: {type(tree_a)}")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.10
3
+ Version: 0.3.12
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=lSwyrPTof_BZ-pyPNhNICJnCZMN9i2sJ-Ii3S_vY_28,16666
1
+ xax/__init__.py,sha256=HXD6tR7Bz1b5ImFyRyR1kAok-dx5g8eBDpO_lCIP8rk,16782
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -9,7 +9,7 @@ xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
9
9
  xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
10
10
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  xax/nn/attention.py,sha256=m6yEoRqf7-wLgrEltaR6CxF_Cody0MaNtAkuKk39qJI,31176
12
- xax/nn/distributions.py,sha256=096IDvoJ0ZA4SqcfgNSmrICsGcsKVcTAh0Vl6SwN3-o,6343
12
+ xax/nn/distributions.py,sha256=6YOjyiPOC7XLDaMYpFNBlLCu3eLgDAeqIg9FoKfYLL4,6497
13
13
  xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
14
14
  xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
15
15
  xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
@@ -43,7 +43,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
43
43
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
44
44
  xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
45
45
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
46
- xax/task/mixins/train.py,sha256=_kDpifLi1arSuT0ssFhBV0axpvLlQG3a97pohya0Eqc,32908
46
+ xax/task/mixins/train.py,sha256=hwAR_G1kgvhXgrE5ZRNL4Jn-Teflx65_1bdk6aULXEg,32814
47
47
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
48
  xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
49
49
  xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
@@ -52,7 +52,7 @@ xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
52
52
  xax/utils/logging.py,sha256=Kkyma_LJXqrN2HTQ214gRP_9ih3_bKk115MWC60lQWM,6656
53
53
  xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
54
54
  xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
55
- xax/utils/pytree.py,sha256=w8Ab2LmJdQ8e1FxKF0xWWaOak09Mhu44ZcOeUR6uGFA,9889
55
+ xax/utils/pytree.py,sha256=e8T5DY0ZhPcbvS3EuOsac0Oprra46lN05WEIhVN-3V0,12670
56
56
  xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
57
57
  xax/utils/text.py,sha256=xS02aSzdywl3KIaNSpKWcxdd37oYlUJtu9wIjkc1wVc,10654
58
58
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -60,9 +60,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
60
60
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
61
61
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
62
62
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
63
- xax-0.3.10.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
64
- xax-0.3.10.dist-info/METADATA,sha256=oQMGYjsfYxMmw0A60qE15yda_G-0YG5RNl17tboR1f0,1247
65
- xax-0.3.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
66
- xax-0.3.10.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
67
- xax-0.3.10.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
68
- xax-0.3.10.dist-info/RECORD,,
63
+ xax-0.3.12.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
64
+ xax-0.3.12.dist-info/METADATA,sha256=RACxHJ_iF4r0BTTTgyTI1ExYF_-aXRWrsq3NlQC7l9A,1247
65
+ xax-0.3.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
66
+ xax-0.3.12.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
67
+ xax-0.3.12.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
68
+ xax-0.3.12.dist-info/RECORD,,
File without changes