xax-0.3.9.tar.gz → xax-0.3.11.tar.gz

This diff shows the contents of two publicly available versions of this package, both released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their respective public registries.
Files changed (75)
  1. {xax-0.3.9/xax.egg-info → xax-0.3.11}/PKG-INFO +1 -1
  2. {xax-0.3.9 → xax-0.3.11}/xax/__init__.py +1 -1
  3. {xax-0.3.9 → xax-0.3.11}/xax/nn/distributions.py +52 -53
  4. {xax-0.3.9 → xax-0.3.11}/xax/task/mixins/train.py +0 -1
  5. {xax-0.3.9 → xax-0.3.11}/xax/utils/pytree.py +11 -4
  6. {xax-0.3.9 → xax-0.3.11/xax.egg-info}/PKG-INFO +1 -1
  7. {xax-0.3.9 → xax-0.3.11}/LICENSE +0 -0
  8. {xax-0.3.9 → xax-0.3.11}/MANIFEST.in +0 -0
  9. {xax-0.3.9 → xax-0.3.11}/README.md +0 -0
  10. {xax-0.3.9 → xax-0.3.11}/pyproject.toml +0 -0
  11. {xax-0.3.9 → xax-0.3.11}/setup.cfg +0 -0
  12. {xax-0.3.9 → xax-0.3.11}/setup.py +0 -0
  13. {xax-0.3.9 → xax-0.3.11}/xax/cli/__init__.py +0 -0
  14. {xax-0.3.9 → xax-0.3.11}/xax/cli/edit_config.py +0 -0
  15. {xax-0.3.9 → xax-0.3.11}/xax/core/__init__.py +0 -0
  16. {xax-0.3.9 → xax-0.3.11}/xax/core/conf.py +0 -0
  17. {xax-0.3.9 → xax-0.3.11}/xax/core/state.py +0 -0
  18. {xax-0.3.9 → xax-0.3.11}/xax/nn/__init__.py +0 -0
  19. {xax-0.3.9 → xax-0.3.11}/xax/nn/attention.py +0 -0
  20. {xax-0.3.9 → xax-0.3.11}/xax/nn/embeddings.py +0 -0
  21. {xax-0.3.9 → xax-0.3.11}/xax/nn/functions.py +0 -0
  22. {xax-0.3.9 → xax-0.3.11}/xax/nn/geom.py +0 -0
  23. {xax-0.3.9 → xax-0.3.11}/xax/nn/losses.py +0 -0
  24. {xax-0.3.9 → xax-0.3.11}/xax/nn/metrics.py +0 -0
  25. {xax-0.3.9 → xax-0.3.11}/xax/nn/parallel.py +0 -0
  26. {xax-0.3.9 → xax-0.3.11}/xax/nn/ssm.py +0 -0
  27. {xax-0.3.9 → xax-0.3.11}/xax/py.typed +0 -0
  28. {xax-0.3.9 → xax-0.3.11}/xax/requirements-dev.txt +0 -0
  29. {xax-0.3.9 → xax-0.3.11}/xax/requirements.txt +0 -0
  30. {xax-0.3.9 → xax-0.3.11}/xax/task/__init__.py +0 -0
  31. {xax-0.3.9 → xax-0.3.11}/xax/task/base.py +0 -0
  32. {xax-0.3.9 → xax-0.3.11}/xax/task/launchers/__init__.py +0 -0
  33. {xax-0.3.9 → xax-0.3.11}/xax/task/launchers/base.py +0 -0
  34. {xax-0.3.9 → xax-0.3.11}/xax/task/launchers/cli.py +0 -0
  35. {xax-0.3.9 → xax-0.3.11}/xax/task/launchers/single_process.py +0 -0
  36. {xax-0.3.9 → xax-0.3.11}/xax/task/logger.py +0 -0
  37. {xax-0.3.9 → xax-0.3.11}/xax/task/loggers/__init__.py +0 -0
  38. {xax-0.3.9 → xax-0.3.11}/xax/task/loggers/callback.py +0 -0
  39. {xax-0.3.9 → xax-0.3.11}/xax/task/loggers/json.py +0 -0
  40. {xax-0.3.9 → xax-0.3.11}/xax/task/loggers/state.py +0 -0
  41. {xax-0.3.9 → xax-0.3.11}/xax/task/loggers/stdout.py +0 -0
  42. {xax-0.3.9 → xax-0.3.11}/xax/task/loggers/tensorboard.py +0 -0
  43. {xax-0.3.9 → xax-0.3.11}/xax/task/mixins/__init__.py +0 -0
  44. {xax-0.3.9 → xax-0.3.11}/xax/task/mixins/artifacts.py +0 -0
  45. {xax-0.3.9 → xax-0.3.11}/xax/task/mixins/checkpointing.py +0 -0
  46. {xax-0.3.9 → xax-0.3.11}/xax/task/mixins/compile.py +0 -0
  47. {xax-0.3.9 → xax-0.3.11}/xax/task/mixins/cpu_stats.py +0 -0
  48. {xax-0.3.9 → xax-0.3.11}/xax/task/mixins/data_loader.py +0 -0
  49. {xax-0.3.9 → xax-0.3.11}/xax/task/mixins/gpu_stats.py +0 -0
  50. {xax-0.3.9 → xax-0.3.11}/xax/task/mixins/logger.py +0 -0
  51. {xax-0.3.9 → xax-0.3.11}/xax/task/mixins/process.py +0 -0
  52. {xax-0.3.9 → xax-0.3.11}/xax/task/mixins/runnable.py +0 -0
  53. {xax-0.3.9 → xax-0.3.11}/xax/task/mixins/step_wrapper.py +0 -0
  54. {xax-0.3.9 → xax-0.3.11}/xax/task/script.py +0 -0
  55. {xax-0.3.9 → xax-0.3.11}/xax/task/task.py +0 -0
  56. {xax-0.3.9 → xax-0.3.11}/xax/utils/__init__.py +0 -0
  57. {xax-0.3.9 → xax-0.3.11}/xax/utils/data/__init__.py +0 -0
  58. {xax-0.3.9 → xax-0.3.11}/xax/utils/data/collate.py +0 -0
  59. {xax-0.3.9 → xax-0.3.11}/xax/utils/debugging.py +0 -0
  60. {xax-0.3.9 → xax-0.3.11}/xax/utils/experiments.py +0 -0
  61. {xax-0.3.9 → xax-0.3.11}/xax/utils/jax.py +0 -0
  62. {xax-0.3.9 → xax-0.3.11}/xax/utils/jaxpr.py +0 -0
  63. {xax-0.3.9 → xax-0.3.11}/xax/utils/logging.py +0 -0
  64. {xax-0.3.9 → xax-0.3.11}/xax/utils/numpy.py +0 -0
  65. {xax-0.3.9 → xax-0.3.11}/xax/utils/profile.py +0 -0
  66. {xax-0.3.9 → xax-0.3.11}/xax/utils/tensorboard.py +0 -0
  67. {xax-0.3.9 → xax-0.3.11}/xax/utils/text.py +0 -0
  68. {xax-0.3.9 → xax-0.3.11}/xax/utils/types/__init__.py +0 -0
  69. {xax-0.3.9 → xax-0.3.11}/xax/utils/types/frozen_dict.py +0 -0
  70. {xax-0.3.9 → xax-0.3.11}/xax/utils/types/hashable_array.py +0 -0
  71. {xax-0.3.9 → xax-0.3.11}/xax.egg-info/SOURCES.txt +0 -0
  72. {xax-0.3.9 → xax-0.3.11}/xax.egg-info/dependency_links.txt +0 -0
  73. {xax-0.3.9 → xax-0.3.11}/xax.egg-info/entry_points.txt +0 -0
  74. {xax-0.3.9 → xax-0.3.11}/xax.egg-info/requires.txt +0 -0
  75. {xax-0.3.9 → xax-0.3.11}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.9
3
+ Version: 0.3.11
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.9"
15
+ __version__ = "0.3.11"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -12,12 +12,16 @@ __all__ = [
12
12
  "MixtureOfGaussians",
13
13
  ]
14
14
 
15
+ import math
15
16
  from abc import ABC, abstractmethod
16
17
 
17
18
  import jax
18
19
  import jax.numpy as jnp
19
20
  from jaxtyping import Array, PRNGKeyArray
20
21
 
22
+ STD_CLIP = 1e-6
23
+ LOGIT_CLIP = math.log(1e4)
24
+
21
25
 
22
26
  class Distribution(ABC):
23
27
  @abstractmethod
@@ -34,87 +38,91 @@ class Distribution(ABC):
34
38
 
35
39
 
36
40
  class Categorical(Distribution):
37
- def __init__(self, logits_n: Array) -> None:
38
- self.logits_n = logits_n
41
+ def __init__(self, logits_nc: Array, logit_clip: float = LOGIT_CLIP) -> None:
42
+ """Initialize a categorical distribution.
43
+
44
+ Args:
45
+ logits_nc: Array of shape (..., n_categories) containing logits
46
+ logit_clip: Clipping value for logits
47
+ """
48
+ self.logits_nc = jnp.clip(logits_nc, -logit_clip, logit_clip)
39
49
 
40
50
  @property
41
51
  def num_categories(self) -> int:
42
- return self.logits_n.shape[-1]
52
+ return self.logits_nc.shape[-1]
43
53
 
44
- def log_prob(self, x: Array) -> Array:
45
- """Compute log probability for specific categories.
46
-
47
- Args:
48
- x: Array of category indices
49
-
50
- Returns:
51
- Log probabilities for the given categories
52
- """
53
- log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
54
- # Use advanced indexing to get the log probabilities for the given categories
55
- return log_probs[x]
54
+ def log_prob(self, x_n: Array) -> Array:
55
+ log_probs_n = jax.nn.log_softmax(self.logits_nc, axis=-1)
56
+ return log_probs_n[x_n]
56
57
 
57
58
  def sample(self, key: PRNGKeyArray) -> Array:
58
- return jax.random.categorical(key, self.logits_n, axis=-1)
59
+ return jax.random.categorical(key, self.logits_nc, axis=-1)
59
60
 
60
61
  def mode(self) -> Array:
61
- return self.logits_n.argmax(axis=-1)
62
+ return self.logits_nc.argmax(axis=-1)
62
63
 
63
64
  def entropy(self) -> Array:
64
- """Compute entropy of the categorical distribution."""
65
- probs = jax.nn.softmax(self.logits_n, axis=-1)
66
- log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
65
+ probs = jax.nn.softmax(self.logits_nc, axis=-1)
66
+ log_probs = jax.nn.log_softmax(self.logits_nc, axis=-1)
67
67
  return -jnp.sum(probs * log_probs, axis=-1)
68
68
 
69
69
 
70
70
  class Normal(Distribution):
71
- def __init__(self, loc: Array, scale: Array) -> None:
72
- self.loc = loc
73
- self.scale = scale
71
+ def __init__(self, loc_n: Array, scale_n: Array, std_clip: float = STD_CLIP) -> None:
72
+ """Initialize a normal distribution.
73
+
74
+ Args:
75
+ loc_n: Mean of the distribution
76
+ scale_n: Standard deviation of the distribution
77
+ std_clip: Minimum standard deviation
78
+ """
79
+ self.loc_n = loc_n
80
+ self.scale_n = jnp.clip(scale_n, min=std_clip)
74
81
 
75
82
  def log_prob(self, x: Array) -> Array:
76
- return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.scale) - (x - self.loc) ** 2 / (2 * self.scale**2)
83
+ return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.scale_n) - (x - self.loc_n) ** 2 / (2 * self.scale_n**2)
77
84
 
78
85
  def sample(self, key: PRNGKeyArray) -> Array:
79
- return self.loc + self.scale * jax.random.normal(key, self.loc.shape)
86
+ return self.loc_n + self.scale_n * jax.random.normal(key, self.loc_n.shape)
80
87
 
81
88
  def mode(self) -> Array:
82
- return self.loc
89
+ return self.loc_n
83
90
 
84
91
  def entropy(self) -> Array:
85
- return jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.scale)
92
+ return jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.scale_n)
86
93
 
87
94
 
88
95
  class MixtureOfGaussians(Distribution):
89
- def __init__(self, means_nm: Array, stds_nm: Array, logits_nm: Array) -> None:
96
+ def __init__(
97
+ self,
98
+ means_nm: Array,
99
+ stds_nm: Array,
100
+ logits_nm: Array,
101
+ std_clip: float = STD_CLIP,
102
+ logit_clip: float = LOGIT_CLIP,
103
+ ) -> None:
90
104
  """Initialize a mixture of Gaussians.
91
105
 
92
106
  Args:
93
107
  means_nm: Array of shape (..., n_components) containing means
94
108
  stds_nm: Array of shape (..., n_components) containing standard deviations
95
109
  logits_nm: Array of shape (..., n_components) containing mixing logits
110
+ std_clip: Minimum standard deviation
111
+ logit_clip: Clipping value for logits
96
112
  """
97
113
  self.means_nm = means_nm
98
- self.stds_nm = stds_nm
99
- self.logits_nm = logits_nm
100
-
101
- def log_prob(self, x: Array) -> Array:
102
- """Compute log probability of the mixture.
114
+ self.stds_nm = jnp.clip(stds_nm, min=std_clip)
115
+ self.logits_nm = jnp.clip(logits_nm, -logit_clip, logit_clip)
103
116
 
104
- Args:
105
- x: Array of shape (...,) containing values to evaluate
106
-
107
- Returns:
108
- Log probabilities of shape (...,)
109
- """
117
+ def log_prob(self, x_n: Array) -> Array:
110
118
  # Expand x to match component dimensions
111
- x_expanded = x[..., None] # Shape: (..., 1)
119
+ x_n_expanded = x_n[..., None] # Shape: (..., 1)
112
120
 
113
121
  # Compute log probabilities for each component
114
122
  component_log_probs = (
115
123
  -0.5 * jnp.log(2 * jnp.pi)
116
124
  - jnp.log(self.stds_nm)
117
- - (x_expanded - self.means_nm) ** 2 / (2 * self.stds_nm**2)
125
+ - (x_n_expanded - self.means_nm) ** 2 / (2 * self.stds_nm**2)
118
126
  )
119
127
 
120
128
  # Compute mixing weights
@@ -123,16 +131,7 @@ class MixtureOfGaussians(Distribution):
123
131
  # Combine using log-sum-exp trick for numerical stability
124
132
  return jax.scipy.special.logsumexp(component_log_probs + mixing_logits, axis=-1)
125
133
 
126
- def sample(self, key: PRNGKeyArray) -> Array:
127
- """Sample from the mixture of Gaussians.
128
-
129
- Args:
130
- key: PRNG key
131
-
132
- Returns:
133
- Samples of shape (...,) where ... are the batch dimensions
134
- """
135
- # Sample component indices
134
+ def sample(self, key: PRNGKeyArray) -> Array: # Sample component indices
136
135
  component_key, sample_key = jax.random.split(key)
137
136
  component_indices = jax.random.categorical(component_key, self.logits_nm, axis=-1)
138
137
 
@@ -153,8 +152,8 @@ class MixtureOfGaussians(Distribution):
153
152
  noise = jax.random.normal(sample_key, selected_means.shape)
154
153
 
155
154
  # Reshape back to original batch shape
156
- samples = selected_means + selected_stds * noise
157
- return samples.reshape(batch_shape)
155
+ samples_n = selected_means + selected_stds * noise
156
+ return samples_n.reshape(batch_shape)
158
157
 
159
158
  def mode(self) -> Array:
160
159
  """Return the mode of the mixture (approximate - returns mean of highest weight component)."""
@@ -177,7 +177,6 @@ class TrainConfig(
177
177
  step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
178
178
  updates_per_step: int = field(1, help="Number of updates to perform per step")
179
179
  random_seed: int = field(1337, help="Random seed for the task")
180
- global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")
181
180
 
182
181
 
183
182
  Config = TypeVar("Config", bound=TrainConfig)
@@ -1,6 +1,6 @@
1
1
  """Utils for accessing, modifying, and otherwise manipulating pytrees."""
2
2
 
3
- from typing import TypeVar
3
+ from typing import Mapping, Sequence, TypeVar
4
4
 
5
5
  import chex
6
6
  import equinox as eqx
@@ -258,11 +258,18 @@ def tuple_insert(t: tuple[T, ...], index: int, value: T) -> tuple[T, ...]:
258
258
  def get_pytree_mapping(pytree: PyTree) -> dict[str, Array]:
259
259
  leaves: dict[str, Array] = {}
260
260
 
261
+ def _get_str(thing: PyTree) -> str:
262
+ if isinstance(thing, str):
263
+ return thing
264
+ if isinstance(thing, Sequence):
265
+ return "/".join(_get_str(x) for x in thing)
266
+ if isinstance(thing, Mapping):
267
+ return "/".join(f"{_get_str(k)}:{_get_str(v)}" for k, v in thing.items())
268
+ return str(thing)
269
+
261
270
  def _get_leaf(path: tuple, x: PyTree) -> None:
262
271
  if isinstance(x, jnp.ndarray):
263
- # Convert path tuple to string, e.g. (1, 'a', 2) -> '1/a/2'
264
- path_str = "/".join(str(p) for p in path)
265
- leaves[path_str] = x
272
+ leaves[_get_str(path)] = x
266
273
 
267
274
  jax.tree.map_with_path(_get_leaf, pytree)
268
275
  return leaves
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.9
3
+ Version: 0.3.11
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes