xax 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +10 -1
- xax/nn/distributions.py +181 -0
- xax/task/mixins/data_loader.py +7 -2
- xax/task/mixins/train.py +45 -57
- {xax-0.3.7.dist-info → xax-0.3.9.dist-info}/METADATA +1 -1
- {xax-0.3.7.dist-info → xax-0.3.9.dist-info}/RECORD +10 -9
- {xax-0.3.7.dist-info → xax-0.3.9.dist-info}/WHEEL +0 -0
- {xax-0.3.7.dist-info → xax-0.3.9.dist-info}/entry_points.txt +0 -0
- {xax-0.3.7.dist-info → xax-0.3.9.dist-info}/licenses/LICENSE +0 -0
- {xax-0.3.7.dist-info → xax-0.3.9.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.3.7"
+__version__ = "0.3.9"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -31,6 +31,10 @@ __all__ = [
     "TransformerBlock",
     "TransformerCache",
     "TransformerStack",
+    "Categorical",
+    "Distribution",
+    "MixtureOfGaussians",
+    "Normal",
     "FourierEmbeddings",
     "IdentityPositionalEmbeddings",
     "LearnedPositionalEmbeddings",
@@ -219,6 +223,10 @@ NAME_MAP: dict[str, str] = {
     "TransformerBlock": "nn.attention",
     "TransformerCache": "nn.attention",
     "TransformerStack": "nn.attention",
+    "Categorical": "nn.distributions",
+    "Distribution": "nn.distributions",
+    "MixtureOfGaussians": "nn.distributions",
+    "Normal": "nn.distributions",
     "FourierEmbeddings": "nn.embeddings",
     "IdentityPositionalEmbeddings": "nn.embeddings",
     "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -405,6 +413,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         TransformerCache,
         TransformerStack,
     )
+    from xax.nn.distributions import Categorical, Distribution, MixtureOfGaussians, Normal
     from xax.nn.embeddings import (
         EmbeddingKind,
         FourierEmbeddings,
xax/nn/distributions.py
ADDED
@@ -0,0 +1,181 @@
+"""Defines some probability distribution helper functions.
+
+In general, it is preferrable to use Distrax or another library, but we wanted
+to have a simple interface of our own so that we can quickly upgrade Jax
+versions (since Distrax is tied pretty closely to Tensorflow).
+"""
+
+__all__ = [
+    "Distribution",
+    "Categorical",
+    "Normal",
+    "MixtureOfGaussians",
+]
+
+from abc import ABC, abstractmethod
+
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+
+class Distribution(ABC):
+    @abstractmethod
+    def log_prob(self, x: Array) -> Array: ...
+
+    @abstractmethod
+    def sample(self, key: PRNGKeyArray) -> Array: ...
+
+    @abstractmethod
+    def mode(self) -> Array: ...
+
+    @abstractmethod
+    def entropy(self) -> Array: ...
+
+
+class Categorical(Distribution):
+    def __init__(self, logits_n: Array) -> None:
+        self.logits_n = logits_n
+
+    @property
+    def num_categories(self) -> int:
+        return self.logits_n.shape[-1]
+
+    def log_prob(self, x: Array) -> Array:
+        """Compute log probability for specific categories.
+
+        Args:
+            x: Array of category indices
+
+        Returns:
+            Log probabilities for the given categories
+        """
+        log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
+        # Use advanced indexing to get the log probabilities for the given categories
+        return log_probs[x]
+
+    def sample(self, key: PRNGKeyArray) -> Array:
+        return jax.random.categorical(key, self.logits_n, axis=-1)
+
+    def mode(self) -> Array:
+        return self.logits_n.argmax(axis=-1)
+
+    def entropy(self) -> Array:
+        """Compute entropy of the categorical distribution."""
+        probs = jax.nn.softmax(self.logits_n, axis=-1)
+        log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
+        return -jnp.sum(probs * log_probs, axis=-1)
+
+
+class Normal(Distribution):
+    def __init__(self, loc: Array, scale: Array) -> None:
+        self.loc = loc
+        self.scale = scale
+
+    def log_prob(self, x: Array) -> Array:
+        return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.scale) - (x - self.loc) ** 2 / (2 * self.scale**2)
+
+    def sample(self, key: PRNGKeyArray) -> Array:
+        return self.loc + self.scale * jax.random.normal(key, self.loc.shape)
+
+    def mode(self) -> Array:
+        return self.loc
+
+    def entropy(self) -> Array:
+        return jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.scale)
+
+
+class MixtureOfGaussians(Distribution):
+    def __init__(self, means_nm: Array, stds_nm: Array, logits_nm: Array) -> None:
+        """Initialize a mixture of Gaussians.
+
+        Args:
+            means_nm: Array of shape (..., n_components) containing means
+            stds_nm: Array of shape (..., n_components) containing standard deviations
+            logits_nm: Array of shape (..., n_components) containing mixing logits
+        """
+        self.means_nm = means_nm
+        self.stds_nm = stds_nm
+        self.logits_nm = logits_nm
+
+    def log_prob(self, x: Array) -> Array:
+        """Compute log probability of the mixture.
+
+        Args:
+            x: Array of shape (...,) containing values to evaluate
+
+        Returns:
+            Log probabilities of shape (...,)
+        """
+        # Expand x to match component dimensions
+        x_expanded = x[..., None]  # Shape: (..., 1)
+
+        # Compute log probabilities for each component
+        component_log_probs = (
+            -0.5 * jnp.log(2 * jnp.pi)
+            - jnp.log(self.stds_nm)
+            - (x_expanded - self.means_nm) ** 2 / (2 * self.stds_nm**2)
+        )
+
+        # Compute mixing weights
+        mixing_logits = jax.nn.log_softmax(self.logits_nm, axis=-1)
+
+        # Combine using log-sum-exp trick for numerical stability
+        return jax.scipy.special.logsumexp(component_log_probs + mixing_logits, axis=-1)
+
+    def sample(self, key: PRNGKeyArray) -> Array:
+        """Sample from the mixture of Gaussians.
+
+        Args:
+            key: PRNG key
+
+        Returns:
+            Samples of shape (...,) where ... are the batch dimensions
+        """
+        # Sample component indices
+        component_key, sample_key = jax.random.split(key)
+        component_indices = jax.random.categorical(component_key, self.logits_nm, axis=-1)
+
+        # Sample from selected components using advanced indexing
+        # We need to handle the case where we have batch dimensions
+        batch_shape = self.means_nm.shape[:-1]  # All dimensions except the last (components)
+
+        # Reshape for easier indexing
+        means_flat = self.means_nm.reshape(-1, self.means_nm.shape[-1])
+        stds_flat = self.stds_nm.reshape(-1, self.stds_nm.shape[-1])
+        indices_flat = component_indices.reshape(-1)
+
+        # Get selected means and stds
+        selected_means = means_flat[jnp.arange(len(indices_flat)), indices_flat]
+        selected_stds = stds_flat[jnp.arange(len(indices_flat)), indices_flat]
+
+        # Generate random noise
+        noise = jax.random.normal(sample_key, selected_means.shape)
+
+        # Reshape back to original batch shape
+        samples = selected_means + selected_stds * noise
+        return samples.reshape(batch_shape)
+
+    def mode(self) -> Array:
+        """Return the mode of the mixture (approximate - returns mean of highest weight component)."""
+        mixing_weights = jax.nn.softmax(self.logits_nm, axis=-1)
+        max_weight_idx = jnp.argmax(mixing_weights, axis=-1)
+
+        # Use advanced indexing to get the means of the highest weight components
+        batch_shape = self.means_nm.shape[:-1]
+        means_flat = self.means_nm.reshape(-1, self.means_nm.shape[-1])
+        indices_flat = max_weight_idx.reshape(-1)
+
+        selected_means = means_flat[jnp.arange(len(indices_flat)), indices_flat]
+        return selected_means.reshape(batch_shape)
+
+    def entropy(self) -> Array:
+        """Compute entropy of the mixture (approximate)."""
+        mixing_weights = jax.nn.softmax(self.logits_nm, axis=-1)
+        component_entropies = jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.stds_nm)
+
+        # Weighted sum of component entropies plus mixing entropy
+        weighted_entropies = jnp.sum(mixing_weights * component_entropies, axis=-1)
+        mixing_entropy = -jnp.sum(mixing_weights * jnp.log(mixing_weights + 1e-8), axis=-1)
+
+        return weighted_entropies + mixing_entropy
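As a quick orientation for the new module, here is a small usage sketch. It is not taken from the package's own docs or tests; it assumes the lazy re-exports added in xax/__init__.py above work as in previous releases (otherwise import from xax.nn.distributions directly).

import jax
import jax.numpy as jnp

import xax

key = jax.random.PRNGKey(0)

# Diagonal Gaussian: log_prob is elementwise, so it has the same shape as its input.
normal = xax.Normal(loc=jnp.zeros(3), scale=jnp.ones(3))
action = normal.sample(key)        # shape (3,)
logp = normal.log_prob(action)     # shape (3,)

# Categorical over 5 classes.
cat = xax.Categorical(logits_n=jnp.array([0.1, 0.2, 0.3, 0.2, 0.2]))
idx = cat.sample(key)              # scalar index
ent = cat.entropy()                # scalar

# Mixture of 4 Gaussians evaluated on a batch of 2 values.
mog = xax.MixtureOfGaussians(
    means_nm=jnp.zeros((2, 4)),
    stds_nm=jnp.ones((2, 4)),
    logits_nm=jnp.zeros((2, 4)),
)
print(mog.log_prob(jnp.array([0.0, 1.0])).shape)  # (2,)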
xax/task/mixins/data_loader.py
CHANGED
@@ -110,7 +110,12 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
             "or `get_data_iterator` to return an iterator for the given dataset."
         )
 
-    def get_dataloader(self, dataset: Dataset[T, Tc_co], phase: Phase) -> Dataloader[T, Tc_co]:
+    def get_dataloader(
+        self,
+        dataset: Dataset[T, Tc_co],
+        phase: Phase,
+        prefetch_factor: int | None = None,
+    ) -> Dataloader[T, Tc_co]:
         debugging = self.config.debug_dataloader
         if debugging:
             logger.warning("Parallel dataloaders disabled in debugging mode")
@@ -135,7 +140,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
             dataset=dataset,
             batch_size=self.config.batch_size,
             num_workers=0 if debugging else cfg.num_workers,
-            prefetch_factor=cfg.prefetch_factor,
+            prefetch_factor=cfg.prefetch_factor if prefetch_factor is None else prefetch_factor,
             mp_manager=self.multiprocessing_manager,
             dataloader_worker_init_fn=self.dataloader_worker_init_fn,
             collate_worker_init_fn=self.collate_worker_init_fn,
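The net effect is that callers can now override the configured prefetch factor per call, and the argument falls back to cfg.prefetch_factor when omitted. A minimal sketch, assuming a task subclass with the usual dataloader config (hypothetical usage, not from the package's tests):

# Inside a DataloadersMixin subclass:
train_dl = self.get_dataloader(train_ds, "train", prefetch_factor=4)  # explicit override
valid_dl = self.get_dataloader(valid_ds, "valid")                     # falls back to cfg.prefetch_factor

The train mixin below uses this hook to prefetch updates_per_step + 1 batches per training step.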
xax/task/mixins/train.py
CHANGED
@@ -60,7 +60,7 @@ from xax.utils.experiments import (
     get_state_file_string,
     get_training_code,
 )
-from xax.utils.jax import jit as xax_jit
+from xax.utils.jax import jit as xax_jit, scan as xax_scan
 from xax.utils.logging import LOG_PING, LOG_STATUS
 from xax.utils.pytree import get_pytree_param_count
 from xax.utils.text import highlight_exception_message, show_info
@@ -175,6 +175,7 @@ class TrainConfig(
     valid_first_n_seconds: float | None = field(60.0, help="Run first validation after N seconds")
     max_steps: int | None = field(None, help="Maximum number of steps to run")
     step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
+    updates_per_step: int = field(1, help="Number of updates to perform per step")
     random_seed: int = field(1337, help="Random seed for the task")
     global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")
 
@@ -597,6 +598,7 @@ class TrainMixin(
         metrics = self.compute_metrics(model, batch, output, loss, state)
         return loss, (output, metrics)
 
+    @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
     def update(
         self,
         model_arr: PyTree,
@@ -609,44 +611,9 @@ class TrainMixin(
         grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
         grad_fn = xax_jit(static_argnums=[1], jit_level=3)(grad_fn)
         grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
-
-
-
-    @xax_jit(static_argnames=["self", "optimizer"], jit_level=3)
-    def apply_gradients_with_clipping(
-        self,
-        model_arr: PyTree,
-        grads: PyTree,
-        optimizer: optax.GradientTransformation,
-        opt_state: optax.OptState,
-    ) -> tuple[PyTree, optax.OptState, dict[str, Array]]:
-        grad_norm = optax.global_norm(grads)
-        grad_metrics = {"grad_norm": grad_norm}
-
-        def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
-            # Clip gradients based on global norm, similar to optax.clip_by_global_norm
-            trigger = jnp.squeeze(grad_norm < self.config.global_grad_clip)
-
-            def clip_fn(t: Array) -> Array:
-                return jax.lax.select(trigger, t, (t / grad_norm.astype(t.dtype)) * self.config.global_grad_clip)
-
-            grads = jax.tree.map(clip_fn, grads)
-
-            # Apply the gradient updates.
-            updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
-            new_model_arr = eqx.apply_updates(model_arr, updates)
-            return new_model_arr, new_opt_state
-
-        # Don't apply updates if the gradient is NaN or Inf.
-        new_model_arr, new_opt_state = jax.lax.cond(
-            jnp.isnan(grad_norm) | jnp.isinf(grad_norm),
-            lambda *_: (model_arr, opt_state),
-            apply,
-            grads,
-            grad_norm,
-        )
-
-        return new_model_arr, new_opt_state, grad_metrics
+        updates, opt_state = optimizer.update(grads, opt_state, model_arr)
+        model_arr = eqx.apply_updates(model_arr, updates)
+        return model_arr, opt_state, output, metrics
 
     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
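The new update therefore applies the optimizer directly through optax and equinox instead of the removed apply_gradients_with_clipping helper. Below is a minimal, self-contained sketch of that pattern with a toy model; folding gradient clipping into the optax chain is shown purely as an illustration and is an assumption, since the diff does not say where xax handles clipping now.

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

model = eqx.nn.Linear(2, 1, key=jax.random.PRNGKey(0))
model_arr, model_static = eqx.partition(model, eqx.is_inexact_array)

# Clipping can live in the optimizer chain instead of a hand-rolled helper (assumption).
optimizer = optax.chain(optax.clip_by_global_norm(10.0), optax.adam(1e-3))
opt_state = optimizer.init(model_arr)

def loss_fn(arr, x, y):
    mdl = eqx.combine(arr, model_static)
    return jnp.mean((jax.vmap(mdl)(x) - y) ** 2)

x, y = jnp.ones((8, 2)), jnp.zeros((8, 1))
grads = jax.grad(loss_fn)(model_arr, x, y)
updates, opt_state = optimizer.update(grads, opt_state, model_arr)  # same two calls as in the diff
model_arr = eqx.apply_updates(model_arr, updates)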
@@ -729,11 +696,36 @@ class TrainMixin(
         model_static: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
-
+        batches: Batch,
         state: State,
     ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
-
-
+        def update_fn(
+            carry: tuple[PyTree, optax.OptState],
+            batch: Batch,
+        ) -> tuple[tuple[PyTree, optax.OptState], tuple[Output, FrozenDict[str, Array]]]:
+            model_arr, opt_state = carry
+            model_arr, opt_state, output, metrics = self.update(
+                model_arr,
+                model_static,
+                optimizer,
+                opt_state,
+                batch,
+                state,
+            )
+            return (model_arr, opt_state), (output, FrozenDict(metrics))
+
+        (model_arr, opt_state), (output, metrics) = xax_scan(
+            update_fn,
+            (model_arr, opt_state),
+            batches,
+            jit_level=3,
+        )
+
+        # Only get the final output and metrics.
+        output = jax.tree.map(lambda x: x[-1], output)
+        metrics = jax.tree.map(lambda x: x[-1], metrics)
+
+        return model_arr, opt_state, output, metrics
 
     @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
     def val_step(
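Behaviourally, the scanned inner loop is the standard scan-over-a-leading-axis pattern. Here is a stand-alone sketch that uses jax.lax.scan directly in place of xax_scan, with a toy update function; all names are illustrative, not the package's own.

import jax
import jax.numpy as jnp

def update(params, batch):
    # Toy stand-in for TrainMixin.update: one gradient step on a quadratic loss.
    grads = jax.grad(lambda p, b: jnp.mean((p - b) ** 2))(params, batch)
    params = params - 0.1 * grads
    return params, {"loss": jnp.mean((params - batch) ** 2)}

def train_step(params, batches):
    def body(carry, batch):
        carry, metrics = update(carry, batch)
        return carry, metrics

    params, metrics = jax.lax.scan(body, params, batches)
    # Keep only the metrics from the final update, mirroring the diff above.
    return params, jax.tree.map(lambda m: m[-1], metrics)

params = jnp.zeros(3)
batches = jnp.ones((4, 3))  # leading axis plays the role of updates_per_step
params, metrics = train_step(params, batches)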
@@ -775,40 +767,36 @@ class TrainMixin(
                 output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
                 self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
 
-            state = state.replace(
-                num_steps=state.num_steps + 1,
-                num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
-            )
-
             state = state.replace(
+                num_steps=state.num_steps + 1,
+                num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
                 elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
             )
 
         with ContextTimer() as timer:
             state = self.on_step_start(state)
             state = state.replace(phase="train")
-
+            train_batches = list(itertools.islice(train_pf, self.config.updates_per_step))
             model_arr, opt_state, output, metrics = self.train_step(
                 model_arr=model_arr,
                 model_static=model_static,
                 optimizer=optimizer,
                 opt_state=opt_state,
-
+                batches=jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches),
                 state=state,
             )
-            self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
-
-            state = state.replace(
-                num_steps=state.num_steps + 1,
-                num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
-            )
-
+            self.log_step(eqx.combine(model_arr, model_static), train_batches[-1], output, metrics, state)
             state = self.on_step_end(state)
 
             state = state.replace(
+                num_steps=state.num_steps + 1,
+                num_samples=state.num_samples + (self.get_size_of_batch(train_batches[-1]) or 0),
                 elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
             )
 
+            if state.num_steps <= 3:
+                logger.log(LOG_PING, "Step %d took %.2f second", state.num_steps, timer.elapsed_time)
+
         if self.should_checkpoint(state):
             model = eqx.combine(model_arr, model_static)
             self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
@@ -827,7 +815,7 @@ class TrainMixin(
             pass
 
         train_ds = self.get_dataset("train")
-        train_dl = self.get_dataloader(train_ds, "train")
+        train_dl = self.get_dataloader(train_ds, "train", prefetch_factor=self.config.updates_per_step + 1)
         train_pf = self.get_prefetcher(train_dl)
 
        try:
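For reference, the jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches) call above stacks the list of per-update batch pytrees into a single pytree with a leading updates_per_step axis, which is exactly what the scanned train_step consumes. A tiny sketch with a hypothetical dict-shaped batch:

import jax
import jax.numpy as jnp

# Three hypothetical batches, each a dict of arrays with batch size 8.
train_batches = [{"x": jnp.ones((8, 2)), "y": jnp.zeros((8,))} for _ in range(3)]

stacked = jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches)
print(stacked["x"].shape)  # (3, 8, 2) -- leading axis is updates_per_step
print(stacked["y"].shape)  # (3, 8)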
{xax-0.3.7.dist-info → xax-0.3.9.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=PLNyZ5fOm0f2JduTMauNH2jqxN4g0GAZ9MSNzEsAQS4,16665
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -9,6 +9,7 @@ xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
 xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/nn/attention.py,sha256=m6yEoRqf7-wLgrEltaR6CxF_Cody0MaNtAkuKk39qJI,31176
+xax/nn/distributions.py,sha256=096IDvoJ0ZA4SqcfgNSmrICsGcsKVcTAh0Vl6SwN3-o,6343
 xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
 xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
 xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
@@ -36,13 +37,13 @@ xax/task/mixins/artifacts.py,sha256=R-y3p7__zJHlHDqwDVAZysg2ZmebCJbqAx_xGT2Xpd0,
 xax/task/mixins/checkpointing.py,sha256=v50IZ7j58DWmEu-_6Zh_02R5KUVGhrMkg5n-MYM_J4c,11484
 xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
 xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
-xax/task/mixins/data_loader.py,sha256=
+xax/task/mixins/data_loader.py,sha256=BKfOVWXR70vbyHMFlnlUiQQHXHH5zTj5WtmsymNCFB4,6722
 xax/task/mixins/gpu_stats.py,sha256=USOyhXldxbsrl6eCtoFKTWUm_lfeG0cUCkQNUpXRdtA,8880
 xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=_kDpifLi1arSuT0ssFhBV0axpvLlQG3a97pohya0Eqc,32908
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
 xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
@@ -59,9 +60,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.3.7.dist-info/licenses/LICENSE,sha256=
-xax-0.3.7.dist-info/METADATA,sha256=
-xax-0.3.7.dist-info/WHEEL,sha256=
-xax-0.3.7.dist-info/entry_points.txt,sha256=
-xax-0.3.7.dist-info/top_level.txt,sha256=
-xax-0.3.7.dist-info/RECORD,,
+xax-0.3.9.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.3.9.dist-info/METADATA,sha256=UKUqiAADutUzjtq4WpChJJzWDcfKBsrbWMmZu0LUqQk,1246
+xax-0.3.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+xax-0.3.9.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+xax-0.3.9.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.3.9.dist-info/RECORD,,
{xax-0.3.7.dist-info → xax-0.3.9.dist-info}/WHEEL
File without changes
{xax-0.3.7.dist-info → xax-0.3.9.dist-info}/entry_points.txt
File without changes
{xax-0.3.7.dist-info → xax-0.3.9.dist-info}/licenses/LICENSE
File without changes
{xax-0.3.7.dist-info → xax-0.3.9.dist-info}/top_level.txt
File without changes