xax-0.3.10.tar.gz → xax-0.3.11.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.3.10/xax.egg-info → xax-0.3.11}/PKG-INFO +1 -1
- {xax-0.3.10 → xax-0.3.11}/xax/__init__.py +1 -1
- {xax-0.3.10 → xax-0.3.11}/xax/nn/distributions.py +52 -53
- {xax-0.3.10 → xax-0.3.11}/xax/task/mixins/train.py +0 -1
- {xax-0.3.10 → xax-0.3.11/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.3.10 → xax-0.3.11}/LICENSE +0 -0
- {xax-0.3.10 → xax-0.3.11}/MANIFEST.in +0 -0
- {xax-0.3.10 → xax-0.3.11}/README.md +0 -0
- {xax-0.3.10 → xax-0.3.11}/pyproject.toml +0 -0
- {xax-0.3.10 → xax-0.3.11}/setup.cfg +0 -0
- {xax-0.3.10 → xax-0.3.11}/setup.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/cli/__init__.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/cli/edit_config.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/core/__init__.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/core/conf.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/core/state.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/nn/__init__.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/nn/attention.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/nn/embeddings.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/nn/functions.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/nn/geom.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/nn/losses.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/nn/metrics.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/nn/parallel.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/nn/ssm.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/py.typed +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/requirements-dev.txt +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/requirements.txt +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/__init__.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/base.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/launchers/__init__.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/launchers/base.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/launchers/cli.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/launchers/single_process.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/logger.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/loggers/__init__.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/loggers/callback.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/loggers/json.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/loggers/state.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/loggers/stdout.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/mixins/__init__.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/mixins/compile.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/mixins/logger.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/mixins/process.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/mixins/runnable.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/script.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/task/task.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/__init__.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/data/__init__.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/data/collate.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/debugging.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/experiments.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/jax.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/jaxpr.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/logging.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/numpy.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/profile.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/pytree.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/tensorboard.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/text.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/types/__init__.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax.egg-info/requires.txt +0 -0
- {xax-0.3.10 → xax-0.3.11}/xax.egg-info/top_level.txt +0 -0
@@ -12,12 +12,16 @@ __all__ = [
|
|
12
12
|
"MixtureOfGaussians",
|
13
13
|
]
|
14
14
|
|
15
|
+
import math
|
15
16
|
from abc import ABC, abstractmethod
|
16
17
|
|
17
18
|
import jax
|
18
19
|
import jax.numpy as jnp
|
19
20
|
from jaxtyping import Array, PRNGKeyArray
|
20
21
|
|
22
|
+
STD_CLIP = 1e-6
|
23
|
+
LOGIT_CLIP = math.log(1e4)
|
24
|
+
|
21
25
|
|
22
26
|
class Distribution(ABC):
|
23
27
|
@abstractmethod
|
@@ -34,87 +38,91 @@ class Distribution(ABC):
|
|
34
38
|
|
35
39
|
|
36
40
|
class Categorical(Distribution):
|
37
|
-
def __init__(self,
|
38
|
-
|
41
|
+
def __init__(self, logits_nc: Array, logit_clip: float = LOGIT_CLIP) -> None:
|
42
|
+
"""Initialize a categorical distribution.
|
43
|
+
|
44
|
+
Args:
|
45
|
+
logits_nc: Array of shape (..., n_categories) containing logits
|
46
|
+
logit_clip: Clipping value for logits
|
47
|
+
"""
|
48
|
+
self.logits_nc = jnp.clip(logits_nc, -logit_clip, logit_clip)
|
39
49
|
|
40
50
|
@property
|
41
51
|
def num_categories(self) -> int:
|
42
|
-
return self.
|
52
|
+
return self.logits_nc.shape[-1]
|
43
53
|
|
44
|
-
def log_prob(self,
|
45
|
-
|
46
|
-
|
47
|
-
Args:
|
48
|
-
x: Array of category indices
|
49
|
-
|
50
|
-
Returns:
|
51
|
-
Log probabilities for the given categories
|
52
|
-
"""
|
53
|
-
log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
|
54
|
-
# Use advanced indexing to get the log probabilities for the given categories
|
55
|
-
return log_probs[x]
|
54
|
+
def log_prob(self, x_n: Array) -> Array:
|
55
|
+
log_probs_n = jax.nn.log_softmax(self.logits_nc, axis=-1)
|
56
|
+
return log_probs_n[x_n]
|
56
57
|
|
57
58
|
def sample(self, key: PRNGKeyArray) -> Array:
|
58
|
-
return jax.random.categorical(key, self.
|
59
|
+
return jax.random.categorical(key, self.logits_nc, axis=-1)
|
59
60
|
|
60
61
|
def mode(self) -> Array:
|
61
|
-
return self.
|
62
|
+
return self.logits_nc.argmax(axis=-1)
|
62
63
|
|
63
64
|
def entropy(self) -> Array:
|
64
|
-
|
65
|
-
|
66
|
-
log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
|
65
|
+
probs = jax.nn.softmax(self.logits_nc, axis=-1)
|
66
|
+
log_probs = jax.nn.log_softmax(self.logits_nc, axis=-1)
|
67
67
|
return -jnp.sum(probs * log_probs, axis=-1)
|
68
68
|
|
69
69
|
|
70
70
|
class Normal(Distribution):
|
71
|
-
def __init__(self,
|
72
|
-
|
73
|
-
|
71
|
+
def __init__(self, loc_n: Array, scale_n: Array, std_clip: float = STD_CLIP) -> None:
|
72
|
+
"""Initialize a normal distribution.
|
73
|
+
|
74
|
+
Args:
|
75
|
+
loc_n: Mean of the distribution
|
76
|
+
scale_n: Standard deviation of the distribution
|
77
|
+
std_clip: Minimum standard deviation
|
78
|
+
"""
|
79
|
+
self.loc_n = loc_n
|
80
|
+
self.scale_n = jnp.clip(scale_n, min=std_clip)
|
74
81
|
|
75
82
|
def log_prob(self, x: Array) -> Array:
|
76
|
-
return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.
|
83
|
+
return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.scale_n) - (x - self.loc_n) ** 2 / (2 * self.scale_n**2)
|
77
84
|
|
78
85
|
def sample(self, key: PRNGKeyArray) -> Array:
|
79
|
-
return self.
|
86
|
+
return self.loc_n + self.scale_n * jax.random.normal(key, self.loc_n.shape)
|
80
87
|
|
81
88
|
def mode(self) -> Array:
|
82
|
-
return self.
|
89
|
+
return self.loc_n
|
83
90
|
|
84
91
|
def entropy(self) -> Array:
|
85
|
-
return jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.
|
92
|
+
return jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.scale_n)
|
86
93
|
|
87
94
|
|
88
95
|
class MixtureOfGaussians(Distribution):
|
89
|
-
def __init__(
|
96
|
+
def __init__(
|
97
|
+
self,
|
98
|
+
means_nm: Array,
|
99
|
+
stds_nm: Array,
|
100
|
+
logits_nm: Array,
|
101
|
+
std_clip: float = STD_CLIP,
|
102
|
+
logit_clip: float = LOGIT_CLIP,
|
103
|
+
) -> None:
|
90
104
|
"""Initialize a mixture of Gaussians.
|
91
105
|
|
92
106
|
Args:
|
93
107
|
means_nm: Array of shape (..., n_components) containing means
|
94
108
|
stds_nm: Array of shape (..., n_components) containing standard deviations
|
95
109
|
logits_nm: Array of shape (..., n_components) containing mixing logits
|
110
|
+
std_clip: Minimum standard deviation
|
111
|
+
logit_clip: Clipping value for logits
|
96
112
|
"""
|
97
113
|
self.means_nm = means_nm
|
98
|
-
self.stds_nm = stds_nm
|
99
|
-
self.logits_nm = logits_nm
|
100
|
-
|
101
|
-
def log_prob(self, x: Array) -> Array:
|
102
|
-
"""Compute log probability of the mixture.
|
114
|
+
self.stds_nm = jnp.clip(stds_nm, min=std_clip)
|
115
|
+
self.logits_nm = jnp.clip(logits_nm, -logit_clip, logit_clip)
|
103
116
|
|
104
|
-
|
105
|
-
x: Array of shape (...,) containing values to evaluate
|
106
|
-
|
107
|
-
Returns:
|
108
|
-
Log probabilities of shape (...,)
|
109
|
-
"""
|
117
|
+
def log_prob(self, x_n: Array) -> Array:
|
110
118
|
# Expand x to match component dimensions
|
111
|
-
|
119
|
+
x_n_expanded = x_n[..., None] # Shape: (..., 1)
|
112
120
|
|
113
121
|
# Compute log probabilities for each component
|
114
122
|
component_log_probs = (
|
115
123
|
-0.5 * jnp.log(2 * jnp.pi)
|
116
124
|
- jnp.log(self.stds_nm)
|
117
|
-
- (
|
125
|
+
- (x_n_expanded - self.means_nm) ** 2 / (2 * self.stds_nm**2)
|
118
126
|
)
|
119
127
|
|
120
128
|
# Compute mixing weights
|
@@ -123,16 +131,7 @@ class MixtureOfGaussians(Distribution):
|
|
123
131
|
# Combine using log-sum-exp trick for numerical stability
|
124
132
|
return jax.scipy.special.logsumexp(component_log_probs + mixing_logits, axis=-1)
|
125
133
|
|
126
|
-
def sample(self, key: PRNGKeyArray) -> Array:
|
127
|
-
"""Sample from the mixture of Gaussians.
|
128
|
-
|
129
|
-
Args:
|
130
|
-
key: PRNG key
|
131
|
-
|
132
|
-
Returns:
|
133
|
-
Samples of shape (...,) where ... are the batch dimensions
|
134
|
-
"""
|
135
|
-
# Sample component indices
|
134
|
+
def sample(self, key: PRNGKeyArray) -> Array: # Sample component indices
|
136
135
|
component_key, sample_key = jax.random.split(key)
|
137
136
|
component_indices = jax.random.categorical(component_key, self.logits_nm, axis=-1)
|
138
137
|
|
@@ -153,8 +152,8 @@ class MixtureOfGaussians(Distribution):
|
|
153
152
|
noise = jax.random.normal(sample_key, selected_means.shape)
|
154
153
|
|
155
154
|
# Reshape back to original batch shape
|
156
|
-
|
157
|
-
return
|
155
|
+
samples_n = selected_means + selected_stds * noise
|
156
|
+
return samples_n.reshape(batch_shape)
|
158
157
|
|
159
158
|
def mode(self) -> Array:
|
160
159
|
"""Return the mode of the mixture (approximate - returns mean of highest weight component)."""
|
@@ -177,7 +177,6 @@ class TrainConfig(
|
|
177
177
|
step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
|
178
178
|
updates_per_step: int = field(1, help="Number of updates to perform per step")
|
179
179
|
random_seed: int = field(1337, help="Random seed for the task")
|
180
|
-
global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")
|
181
180
|
|
182
181
|
|
183
182
|
Config = TypeVar("Config", bound=TrainConfig)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|