xax 0.3.7__py3-none-any.whl → 0.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.7"
15
+ __version__ = "0.3.8"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -31,6 +31,10 @@ __all__ = [
31
31
  "TransformerBlock",
32
32
  "TransformerCache",
33
33
  "TransformerStack",
34
+ "Categorical",
35
+ "Distribution",
36
+ "MixtureOfGaussians",
37
+ "Normal",
34
38
  "FourierEmbeddings",
35
39
  "IdentityPositionalEmbeddings",
36
40
  "LearnedPositionalEmbeddings",
@@ -219,6 +223,10 @@ NAME_MAP: dict[str, str] = {
219
223
  "TransformerBlock": "nn.attention",
220
224
  "TransformerCache": "nn.attention",
221
225
  "TransformerStack": "nn.attention",
226
+ "Categorical": "nn.distributions",
227
+ "Distribution": "nn.distributions",
228
+ "MixtureOfGaussians": "nn.distributions",
229
+ "Normal": "nn.distributions",
222
230
  "FourierEmbeddings": "nn.embeddings",
223
231
  "IdentityPositionalEmbeddings": "nn.embeddings",
224
232
  "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -405,6 +413,7 @@ if IMPORT_ALL or TYPE_CHECKING:
405
413
  TransformerCache,
406
414
  TransformerStack,
407
415
  )
416
+ from xax.nn.distributions import Categorical, Distribution, MixtureOfGaussians, Normal
408
417
  from xax.nn.embeddings import (
409
418
  EmbeddingKind,
410
419
  FourierEmbeddings,
@@ -0,0 +1,181 @@
1
+ """Defines some probability distribution helper functions.
2
+
3
+ In general, it is preferable to use Distrax or another library, but we wanted
4
+ to have a simple interface of our own so that we can quickly upgrade Jax
5
+ versions (since Distrax is tied pretty closely to Tensorflow).
6
+ """
7
+
8
+ __all__ = [
9
+ "Distribution",
10
+ "Categorical",
11
+ "Normal",
12
+ "MixtureOfGaussians",
13
+ ]
14
+
15
+ from abc import ABC, abstractmethod
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+ from jaxtyping import Array, PRNGKeyArray
20
+
21
+
22
class Distribution(ABC):
    """Abstract interface implemented by every distribution in this module.

    Subclasses must provide all four methods below before they can be
    instantiated.
    """

    @abstractmethod
    def log_prob(self, x: Array) -> Array:
        """Return the log-probability (or log-density) evaluated at ``x``."""

    @abstractmethod
    def sample(self, key: PRNGKeyArray) -> Array:
        """Draw a sample from the distribution using the PRNG key ``key``."""

    @abstractmethod
    def mode(self) -> Array:
        """Return the mode of the distribution."""

    @abstractmethod
    def entropy(self) -> Array:
        """Return the entropy of the distribution."""
34
+
35
+
36
class Categorical(Distribution):
    """Categorical distribution parameterized by unnormalized logits.

    The last axis of ``logits_n`` indexes the categories; any leading axes are
    treated as batch dimensions.
    """

    def __init__(self, logits_n: Array) -> None:
        """Store the unnormalized logits.

        Args:
            logits_n: Array of shape (..., n_categories) containing logits
        """
        self.logits_n = logits_n

    @property
    def num_categories(self) -> int:
        """Number of categories, i.e. the size of the last logits axis."""
        return self.logits_n.shape[-1]

    def log_prob(self, x: Array) -> Array:
        """Compute log probability for specific categories.

        Args:
            x: Array of integer category indices. Its shape must broadcast
                against the batch shape of the logits (all axes but the last).

        Returns:
            Log probabilities for the given categories, with the broadcast
            batch shape.
        """
        log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
        indices = jnp.asarray(x)

        # Index along the category (last) axis rather than the leading axis,
        # so that batched logits select one entry per batch element. Both
        # operands are broadcast first so that unbatched logits can still be
        # queried with an array of several categories.
        batch_shape = jnp.broadcast_shapes(indices.shape, log_probs.shape[:-1])
        log_probs = jnp.broadcast_to(log_probs, (*batch_shape, log_probs.shape[-1]))
        indices = jnp.broadcast_to(indices, batch_shape)
        return jnp.take_along_axis(log_probs, indices[..., None], axis=-1)[..., 0]

    def sample(self, key: PRNGKeyArray) -> Array:
        """Sample one category index per batch element."""
        return jax.random.categorical(key, self.logits_n, axis=-1)

    def mode(self) -> Array:
        """Return the index of the highest-logit category."""
        return self.logits_n.argmax(axis=-1)

    def entropy(self) -> Array:
        """Compute entropy of the categorical distribution."""
        probs = jax.nn.softmax(self.logits_n, axis=-1)
        log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
        return -jnp.sum(probs * log_probs, axis=-1)
68
+
69
+
70
class Normal(Distribution):
    """Univariate Gaussian distribution, applied elementwise."""

    def __init__(self, loc: Array, scale: Array) -> None:
        """Store the distribution parameters.

        Args:
            loc: Mean of the distribution
            scale: Standard deviation of the distribution
        """
        self.loc = loc
        self.scale = scale

    def log_prob(self, x: Array) -> Array:
        """Return the Gaussian log-density evaluated at ``x``."""
        return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.scale) - (x - self.loc) ** 2 / (2 * self.scale**2)

    def sample(self, key: PRNGKeyArray) -> Array:
        """Draw one sample per element of the broadcast parameter shape."""
        # Draw noise with the broadcast shape of both parameters so that a
        # scale with extra axes gets independent noise along them, instead of
        # one draw shaped like ``loc`` being reused.
        shape = jnp.broadcast_shapes(jnp.shape(self.loc), jnp.shape(self.scale))
        return self.loc + self.scale * jax.random.normal(key, shape)

    def mode(self) -> Array:
        """Return the mode, which for a Gaussian is its mean."""
        return self.loc

    def entropy(self) -> Array:
        """Return the differential entropy, 0.5 * log(2 * pi * e) + log(scale)."""
        return 0.5 * jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.scale)
86
+
87
+
88
class MixtureOfGaussians(Distribution):
    """Mixture of univariate Gaussians; the last axis indexes the components."""

    def __init__(self, means_nm: Array, stds_nm: Array, logits_nm: Array) -> None:
        """Initialize a mixture of Gaussians.

        Args:
            means_nm: Array of shape (..., n_components) containing means
            stds_nm: Array of shape (..., n_components) containing standard deviations
            logits_nm: Array of shape (..., n_components) containing mixing logits
        """
        self.means_nm = means_nm
        self.stds_nm = stds_nm
        self.logits_nm = logits_nm

    @staticmethod
    def _select_component(values_nm: Array, indices_n: Array) -> Array:
        """Pick, for each batch element, the entry at the given component index."""
        return jnp.take_along_axis(values_nm, indices_n[..., None], axis=-1)[..., 0]

    def log_prob(self, x: Array) -> Array:
        """Compute log probability of the mixture.

        Args:
            x: Array of shape (...,) containing values to evaluate

        Returns:
            Log probabilities of shape (...,)
        """
        # Add a trailing axis so x broadcasts against the component axis.
        x_expanded = x[..., None]

        # Per-component Gaussian log-densities.
        component_log_probs = (
            -0.5 * jnp.log(2 * jnp.pi)
            - jnp.log(self.stds_nm)
            - (x_expanded - self.means_nm) ** 2 / (2 * self.stds_nm**2)
        )

        # Normalized log mixing weights.
        mixing_log_weights = jax.nn.log_softmax(self.logits_nm, axis=-1)

        # Marginalize over components in log space for numerical stability.
        return jax.nn.logsumexp(component_log_probs + mixing_log_weights, axis=-1)

    def sample(self, key: PRNGKeyArray) -> Array:
        """Sample from the mixture of Gaussians.

        Args:
            key: PRNG key

        Returns:
            Samples of shape (...,) where ... are the batch dimensions
        """
        component_key, sample_key = jax.random.split(key)

        # First choose a component per batch element, then sample from it.
        component_indices = jax.random.categorical(component_key, self.logits_nm, axis=-1)
        selected_means = self._select_component(self.means_nm, component_indices)
        selected_stds = self._select_component(self.stds_nm, component_indices)

        noise = jax.random.normal(sample_key, selected_means.shape)
        return selected_means + selected_stds * noise

    def mode(self) -> Array:
        """Return the mode of the mixture (approximate - returns mean of highest weight component)."""
        max_weight_idx = jnp.argmax(self.logits_nm, axis=-1)
        return self._select_component(self.means_nm, max_weight_idx)

    def entropy(self) -> Array:
        """Compute entropy of the mixture (approximate).

        Returns the weighted sum of the component entropies plus the entropy
        of the mixing weights. This is an upper bound on the true mixture
        entropy and is exact when there is a single component.
        """
        mixing_weights = jax.nn.softmax(self.logits_nm, axis=-1)

        # Differential entropy of each Gaussian component.
        component_entropies = 0.5 * jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.stds_nm)

        weighted_entropies = jnp.sum(mixing_weights * component_entropies, axis=-1)
        # The small epsilon keeps log() finite for zero-weight components.
        mixing_entropy = -jnp.sum(mixing_weights * jnp.log(mixing_weights + 1e-8), axis=-1)

        return weighted_entropies + mixing_entropy
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.7
3
+ Version: 0.3.8
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=YCDjLRwliJCyEmNFC56PNQXV9Vn9Fr13VJS_am4h3To,16336
1
+ xax/__init__.py,sha256=5NGaVm9X36LhG-Tl1hc7Lk1SmnTZvyu8G1iFDixpqLc,16665
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -9,6 +9,7 @@ xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
9
9
  xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
10
10
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  xax/nn/attention.py,sha256=m6yEoRqf7-wLgrEltaR6CxF_Cody0MaNtAkuKk39qJI,31176
12
+ xax/nn/distributions.py,sha256=096IDvoJ0ZA4SqcfgNSmrICsGcsKVcTAh0Vl6SwN3-o,6343
12
13
  xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
13
14
  xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
14
15
  xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
@@ -59,9 +60,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
59
60
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
60
61
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
61
62
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
62
- xax-0.3.7.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
63
- xax-0.3.7.dist-info/METADATA,sha256=8Zb0pvTJOjrCHK7giM2MbhlGCPREQewJK3GgRDQNWY0,1246
64
- xax-0.3.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
65
- xax-0.3.7.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
66
- xax-0.3.7.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
67
- xax-0.3.7.dist-info/RECORD,,
63
+ xax-0.3.8.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
64
+ xax-0.3.8.dist-info/METADATA,sha256=d4UVJYHBKGAJTdC8G4IHt9kI44lbexOWIiZnkICd0pM,1246
65
+ xax-0.3.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
66
+ xax-0.3.8.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
67
+ xax-0.3.8.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
68
+ xax-0.3.8.dist-info/RECORD,,
File without changes