xax-0.3.7-py3-none-any.whl → xax-0.3.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +10 -1
- xax/nn/distributions.py +181 -0
- {xax-0.3.7.dist-info → xax-0.3.8.dist-info}/METADATA +1 -1
- {xax-0.3.7.dist-info → xax-0.3.8.dist-info}/RECORD +8 -7
- {xax-0.3.7.dist-info → xax-0.3.8.dist-info}/WHEEL +0 -0
- {xax-0.3.7.dist-info → xax-0.3.8.dist-info}/entry_points.txt +0 -0
- {xax-0.3.7.dist-info → xax-0.3.8.dist-info}/licenses/LICENSE +0 -0
- {xax-0.3.7.dist-info → xax-0.3.8.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.3.7"
+__version__ = "0.3.8"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -31,6 +31,10 @@ __all__ = [
     "TransformerBlock",
     "TransformerCache",
     "TransformerStack",
+    "Categorical",
+    "Distribution",
+    "MixtureOfGaussians",
+    "Normal",
     "FourierEmbeddings",
     "IdentityPositionalEmbeddings",
     "LearnedPositionalEmbeddings",
@@ -219,6 +223,10 @@ NAME_MAP: dict[str, str] = {
     "TransformerBlock": "nn.attention",
     "TransformerCache": "nn.attention",
     "TransformerStack": "nn.attention",
+    "Categorical": "nn.distributions",
+    "Distribution": "nn.distributions",
+    "MixtureOfGaussians": "nn.distributions",
+    "Normal": "nn.distributions",
     "FourierEmbeddings": "nn.embeddings",
     "IdentityPositionalEmbeddings": "nn.embeddings",
     "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -405,6 +413,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         TransformerCache,
         TransformerStack,
     )
+    from xax.nn.distributions import Categorical, Distribution, MixtureOfGaussians, Normal
    from xax.nn.embeddings import (
        EmbeddingKind,
        FourierEmbeddings,
xax/nn/distributions.py
ADDED
@@ -0,0 +1,181 @@
+"""Defines some probability distribution helper functions.
+
+In general, it is preferrable to use Distrax or another library, but we wanted
+to have a simple interface of our own so that we can quickly upgrade Jax
+versions (since Distrax is tied pretty closely to Tensorflow).
+"""
+
+__all__ = [
+    "Distribution",
+    "Categorical",
+    "Normal",
+    "MixtureOfGaussians",
+]
+
+from abc import ABC, abstractmethod
+
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+
+class Distribution(ABC):
+    @abstractmethod
+    def log_prob(self, x: Array) -> Array: ...
+
+    @abstractmethod
+    def sample(self, key: PRNGKeyArray) -> Array: ...
+
+    @abstractmethod
+    def mode(self) -> Array: ...
+
+    @abstractmethod
+    def entropy(self) -> Array: ...
+
+
+class Categorical(Distribution):
+    def __init__(self, logits_n: Array) -> None:
+        self.logits_n = logits_n
+
+    @property
+    def num_categories(self) -> int:
+        return self.logits_n.shape[-1]
+
+    def log_prob(self, x: Array) -> Array:
+        """Compute log probability for specific categories.
+
+        Args:
+            x: Array of category indices
+
+        Returns:
+            Log probabilities for the given categories
+        """
+        log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
+        # Use advanced indexing to get the log probabilities for the given categories
+        return log_probs[x]
+
+    def sample(self, key: PRNGKeyArray) -> Array:
+        return jax.random.categorical(key, self.logits_n, axis=-1)
+
+    def mode(self) -> Array:
+        return self.logits_n.argmax(axis=-1)
+
+    def entropy(self) -> Array:
+        """Compute entropy of the categorical distribution."""
+        probs = jax.nn.softmax(self.logits_n, axis=-1)
+        log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
+        return -jnp.sum(probs * log_probs, axis=-1)
+
+
+class Normal(Distribution):
+    def __init__(self, loc: Array, scale: Array) -> None:
+        self.loc = loc
+        self.scale = scale
+
+    def log_prob(self, x: Array) -> Array:
+        return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.scale) - (x - self.loc) ** 2 / (2 * self.scale**2)
+
+    def sample(self, key: PRNGKeyArray) -> Array:
+        return self.loc + self.scale * jax.random.normal(key, self.loc.shape)
+
+    def mode(self) -> Array:
+        return self.loc
+
+    def entropy(self) -> Array:
+        return jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.scale)
+
+
+class MixtureOfGaussians(Distribution):
+    def __init__(self, means_nm: Array, stds_nm: Array, logits_nm: Array) -> None:
+        """Initialize a mixture of Gaussians.
+
+        Args:
+            means_nm: Array of shape (..., n_components) containing means
+            stds_nm: Array of shape (..., n_components) containing standard deviations
+            logits_nm: Array of shape (..., n_components) containing mixing logits
+        """
+        self.means_nm = means_nm
+        self.stds_nm = stds_nm
+        self.logits_nm = logits_nm
+
+    def log_prob(self, x: Array) -> Array:
+        """Compute log probability of the mixture.
+
+        Args:
+            x: Array of shape (...,) containing values to evaluate
+
+        Returns:
+            Log probabilities of shape (...,)
+        """
+        # Expand x to match component dimensions
+        x_expanded = x[..., None]  # Shape: (..., 1)
+
+        # Compute log probabilities for each component
+        component_log_probs = (
+            -0.5 * jnp.log(2 * jnp.pi)
+            - jnp.log(self.stds_nm)
+            - (x_expanded - self.means_nm) ** 2 / (2 * self.stds_nm**2)
+        )
+
+        # Compute mixing weights
+        mixing_logits = jax.nn.log_softmax(self.logits_nm, axis=-1)
+
+        # Combine using log-sum-exp trick for numerical stability
+        return jax.scipy.special.logsumexp(component_log_probs + mixing_logits, axis=-1)
+
+    def sample(self, key: PRNGKeyArray) -> Array:
+        """Sample from the mixture of Gaussians.
+
+        Args:
+            key: PRNG key
+
+        Returns:
+            Samples of shape (...,) where ... are the batch dimensions
+        """
+        # Sample component indices
+        component_key, sample_key = jax.random.split(key)
+        component_indices = jax.random.categorical(component_key, self.logits_nm, axis=-1)
+
+        # Sample from selected components using advanced indexing
+        # We need to handle the case where we have batch dimensions
+        batch_shape = self.means_nm.shape[:-1]  # All dimensions except the last (components)
+
+        # Reshape for easier indexing
+        means_flat = self.means_nm.reshape(-1, self.means_nm.shape[-1])
+        stds_flat = self.stds_nm.reshape(-1, self.stds_nm.shape[-1])
+        indices_flat = component_indices.reshape(-1)
+
+        # Get selected means and stds
+        selected_means = means_flat[jnp.arange(len(indices_flat)), indices_flat]
+        selected_stds = stds_flat[jnp.arange(len(indices_flat)), indices_flat]
+
+        # Generate random noise
+        noise = jax.random.normal(sample_key, selected_means.shape)
+
+        # Reshape back to original batch shape
+        samples = selected_means + selected_stds * noise
+        return samples.reshape(batch_shape)
+
+    def mode(self) -> Array:
+        """Return the mode of the mixture (approximate - returns mean of highest weight component)."""
+        mixing_weights = jax.nn.softmax(self.logits_nm, axis=-1)
+        max_weight_idx = jnp.argmax(mixing_weights, axis=-1)
+
+        # Use advanced indexing to get the means of the highest weight components
+        batch_shape = self.means_nm.shape[:-1]
+        means_flat = self.means_nm.reshape(-1, self.means_nm.shape[-1])
+        indices_flat = max_weight_idx.reshape(-1)
+
+        selected_means = means_flat[jnp.arange(len(indices_flat)), indices_flat]
+        return selected_means.reshape(batch_shape)
+
+    def entropy(self) -> Array:
+        """Compute entropy of the mixture (approximate)."""
+        mixing_weights = jax.nn.softmax(self.logits_nm, axis=-1)
+        component_entropies = jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.stds_nm)
+
+        # Weighted sum of component entropies plus mixing entropy
+        weighted_entropies = jnp.sum(mixing_weights * component_entropies, axis=-1)
+        mixing_entropy = -jnp.sum(mixing_weights * jnp.log(mixing_weights + 1e-8), axis=-1)
+
+        return weighted_entropies + mixing_entropy
{xax-0.3.7.dist-info → xax-0.3.8.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=5NGaVm9X36LhG-Tl1hc7Lk1SmnTZvyu8G1iFDixpqLc,16665
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -9,6 +9,7 @@ xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
 xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/nn/attention.py,sha256=m6yEoRqf7-wLgrEltaR6CxF_Cody0MaNtAkuKk39qJI,31176
+xax/nn/distributions.py,sha256=096IDvoJ0ZA4SqcfgNSmrICsGcsKVcTAh0Vl6SwN3-o,6343
 xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
 xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
 xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
@@ -59,9 +60,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
-xax-0.3.
+xax-0.3.8.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.3.8.dist-info/METADATA,sha256=d4UVJYHBKGAJTdC8G4IHt9kI44lbexOWIiZnkICd0pM,1246
+xax-0.3.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+xax-0.3.8.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+xax-0.3.8.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.3.8.dist-info/RECORD,,
{xax-0.3.7.dist-info → xax-0.3.8.dist-info}/WHEEL
File without changes
{xax-0.3.7.dist-info → xax-0.3.8.dist-info}/entry_points.txt
File without changes
{xax-0.3.7.dist-info → xax-0.3.8.dist-info}/licenses/LICENSE
File without changes
{xax-0.3.7.dist-info → xax-0.3.8.dist-info}/top_level.txt
File without changes