xax 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.10"
15
+ __version__ = "0.3.11"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
xax/nn/distributions.py CHANGED
@@ -12,12 +12,16 @@ __all__ = [
12
12
  "MixtureOfGaussians",
13
13
  ]
14
14
 
15
+ import math
15
16
  from abc import ABC, abstractmethod
16
17
 
17
18
  import jax
18
19
  import jax.numpy as jnp
19
20
  from jaxtyping import Array, PRNGKeyArray
20
21
 
22
+ STD_CLIP = 1e-6
23
+ LOGIT_CLIP = math.log(1e4)
24
+
21
25
 
22
26
  class Distribution(ABC):
23
27
  @abstractmethod
@@ -34,87 +38,91 @@ class Distribution(ABC):
34
38
 
35
39
 
36
40
  class Categorical(Distribution):
37
- def __init__(self, logits_n: Array) -> None:
38
- self.logits_n = logits_n
41
+ def __init__(self, logits_nc: Array, logit_clip: float = LOGIT_CLIP) -> None:
42
+ """Initialize a categorical distribution.
43
+
44
+ Args:
45
+ logits_nc: Array of shape (..., n_categories) containing logits
46
+ logit_clip: Clipping value for logits
47
+ """
48
+ self.logits_nc = jnp.clip(logits_nc, -logit_clip, logit_clip)
39
49
 
40
50
  @property
41
51
  def num_categories(self) -> int:
42
- return self.logits_n.shape[-1]
52
+ return self.logits_nc.shape[-1]
43
53
 
44
- def log_prob(self, x: Array) -> Array:
45
- """Compute log probability for specific categories.
46
-
47
- Args:
48
- x: Array of category indices
49
-
50
- Returns:
51
- Log probabilities for the given categories
52
- """
53
- log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
54
- # Use advanced indexing to get the log probabilities for the given categories
55
- return log_probs[x]
54
+ def log_prob(self, x_n: Array) -> Array:
55
+ log_probs_n = jax.nn.log_softmax(self.logits_nc, axis=-1)
56
+ return log_probs_n[x_n]
56
57
 
57
58
  def sample(self, key: PRNGKeyArray) -> Array:
58
- return jax.random.categorical(key, self.logits_n, axis=-1)
59
+ return jax.random.categorical(key, self.logits_nc, axis=-1)
59
60
 
60
61
  def mode(self) -> Array:
61
- return self.logits_n.argmax(axis=-1)
62
+ return self.logits_nc.argmax(axis=-1)
62
63
 
63
64
  def entropy(self) -> Array:
64
- """Compute entropy of the categorical distribution."""
65
- probs = jax.nn.softmax(self.logits_n, axis=-1)
66
- log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
65
+ probs = jax.nn.softmax(self.logits_nc, axis=-1)
66
+ log_probs = jax.nn.log_softmax(self.logits_nc, axis=-1)
67
67
  return -jnp.sum(probs * log_probs, axis=-1)
68
68
 
69
69
 
70
70
  class Normal(Distribution):
71
- def __init__(self, loc: Array, scale: Array) -> None:
72
- self.loc = loc
73
- self.scale = scale
71
+ def __init__(self, loc_n: Array, scale_n: Array, std_clip: float = STD_CLIP) -> None:
72
+ """Initialize a normal distribution.
73
+
74
+ Args:
75
+ loc_n: Mean of the distribution
76
+ scale_n: Standard deviation of the distribution
77
+ std_clip: Minimum standard deviation
78
+ """
79
+ self.loc_n = loc_n
80
+ self.scale_n = jnp.clip(scale_n, min=std_clip)
74
81
 
75
82
  def log_prob(self, x: Array) -> Array:
76
- return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.scale) - (x - self.loc) ** 2 / (2 * self.scale**2)
83
+ return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.scale_n) - (x - self.loc_n) ** 2 / (2 * self.scale_n**2)
77
84
 
78
85
  def sample(self, key: PRNGKeyArray) -> Array:
79
- return self.loc + self.scale * jax.random.normal(key, self.loc.shape)
86
+ return self.loc_n + self.scale_n * jax.random.normal(key, self.loc_n.shape)
80
87
 
81
88
  def mode(self) -> Array:
82
- return self.loc
89
+ return self.loc_n
83
90
 
84
91
  def entropy(self) -> Array:
85
- return jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.scale)
92
+ return jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.scale_n)
86
93
 
87
94
 
88
95
  class MixtureOfGaussians(Distribution):
89
- def __init__(self, means_nm: Array, stds_nm: Array, logits_nm: Array) -> None:
96
+ def __init__(
97
+ self,
98
+ means_nm: Array,
99
+ stds_nm: Array,
100
+ logits_nm: Array,
101
+ std_clip: float = STD_CLIP,
102
+ logit_clip: float = LOGIT_CLIP,
103
+ ) -> None:
90
104
  """Initialize a mixture of Gaussians.
91
105
 
92
106
  Args:
93
107
  means_nm: Array of shape (..., n_components) containing means
94
108
  stds_nm: Array of shape (..., n_components) containing standard deviations
95
109
  logits_nm: Array of shape (..., n_components) containing mixing logits
110
+ std_clip: Minimum standard deviation
111
+ logit_clip: Clipping value for logits
96
112
  """
97
113
  self.means_nm = means_nm
98
- self.stds_nm = stds_nm
99
- self.logits_nm = logits_nm
100
-
101
- def log_prob(self, x: Array) -> Array:
102
- """Compute log probability of the mixture.
114
+ self.stds_nm = jnp.clip(stds_nm, min=std_clip)
115
+ self.logits_nm = jnp.clip(logits_nm, -logit_clip, logit_clip)
103
116
 
104
- Args:
105
- x: Array of shape (...,) containing values to evaluate
106
-
107
- Returns:
108
- Log probabilities of shape (...,)
109
- """
117
+ def log_prob(self, x_n: Array) -> Array:
110
118
  # Expand x to match component dimensions
111
- x_expanded = x[..., None] # Shape: (..., 1)
119
+ x_n_expanded = x_n[..., None] # Shape: (..., 1)
112
120
 
113
121
  # Compute log probabilities for each component
114
122
  component_log_probs = (
115
123
  -0.5 * jnp.log(2 * jnp.pi)
116
124
  - jnp.log(self.stds_nm)
117
- - (x_expanded - self.means_nm) ** 2 / (2 * self.stds_nm**2)
125
+ - (x_n_expanded - self.means_nm) ** 2 / (2 * self.stds_nm**2)
118
126
  )
119
127
 
120
128
  # Compute mixing weights
@@ -123,16 +131,7 @@ class MixtureOfGaussians(Distribution):
123
131
  # Combine using log-sum-exp trick for numerical stability
124
132
  return jax.scipy.special.logsumexp(component_log_probs + mixing_logits, axis=-1)
125
133
 
126
- def sample(self, key: PRNGKeyArray) -> Array:
127
- """Sample from the mixture of Gaussians.
128
-
129
- Args:
130
- key: PRNG key
131
-
132
- Returns:
133
- Samples of shape (...,) where ... are the batch dimensions
134
- """
135
- # Sample component indices
134
+ def sample(self, key: PRNGKeyArray) -> Array: # Sample component indices
136
135
  component_key, sample_key = jax.random.split(key)
137
136
  component_indices = jax.random.categorical(component_key, self.logits_nm, axis=-1)
138
137
 
@@ -153,8 +152,8 @@ class MixtureOfGaussians(Distribution):
153
152
  noise = jax.random.normal(sample_key, selected_means.shape)
154
153
 
155
154
  # Reshape back to original batch shape
156
- samples = selected_means + selected_stds * noise
157
- return samples.reshape(batch_shape)
155
+ samples_n = selected_means + selected_stds * noise
156
+ return samples_n.reshape(batch_shape)
158
157
 
159
158
  def mode(self) -> Array:
160
159
  """Return the mode of the mixture (approximate - returns mean of highest weight component)."""
xax/task/mixins/train.py CHANGED
@@ -177,7 +177,6 @@ class TrainConfig(
177
177
  step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
178
178
  updates_per_step: int = field(1, help="Number of updates to perform per step")
179
179
  random_seed: int = field(1337, help="Random seed for the task")
180
- global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")
181
180
 
182
181
 
183
182
  Config = TypeVar("Config", bound=TrainConfig)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.10
3
+ Version: 0.3.11
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=lSwyrPTof_BZ-pyPNhNICJnCZMN9i2sJ-Ii3S_vY_28,16666
1
+ xax/__init__.py,sha256=Kd9-a62JICqpaZqb0WaJMz7qC5uHYghOHZsnCb3EC6Q,16666
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -9,7 +9,7 @@ xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
9
9
  xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
10
10
  xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  xax/nn/attention.py,sha256=m6yEoRqf7-wLgrEltaR6CxF_Cody0MaNtAkuKk39qJI,31176
12
- xax/nn/distributions.py,sha256=096IDvoJ0ZA4SqcfgNSmrICsGcsKVcTAh0Vl6SwN3-o,6343
12
+ xax/nn/distributions.py,sha256=b251blOwdxkWUOaYjOuqcR_HNMfm9I8Aq9EDxoIxHVw,6519
13
13
  xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
14
14
  xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
15
15
  xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
@@ -43,7 +43,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
43
43
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
44
44
  xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
45
45
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
46
- xax/task/mixins/train.py,sha256=_kDpifLi1arSuT0ssFhBV0axpvLlQG3a97pohya0Eqc,32908
46
+ xax/task/mixins/train.py,sha256=hwAR_G1kgvhXgrE5ZRNL4Jn-Teflx65_1bdk6aULXEg,32814
47
47
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
48
  xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
49
49
  xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
@@ -60,9 +60,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
60
60
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
61
61
  xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
62
62
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
63
- xax-0.3.10.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
64
- xax-0.3.10.dist-info/METADATA,sha256=oQMGYjsfYxMmw0A60qE15yda_G-0YG5RNl17tboR1f0,1247
65
- xax-0.3.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
66
- xax-0.3.10.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
67
- xax-0.3.10.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
68
- xax-0.3.10.dist-info/RECORD,,
63
+ xax-0.3.11.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
64
+ xax-0.3.11.dist-info/METADATA,sha256=FaS2TIfJ5ExcZYXP1KBugCPal5jexn_HZ5oFQCDvq9g,1247
65
+ xax-0.3.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
66
+ xax-0.3.11.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
67
+ xax-0.3.11.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
68
+ xax-0.3.11.dist-info/RECORD,,
File without changes