xax 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.3.7"
+__version__ = "0.3.9"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -31,6 +31,10 @@ __all__ = [
     "TransformerBlock",
     "TransformerCache",
     "TransformerStack",
+    "Categorical",
+    "Distribution",
+    "MixtureOfGaussians",
+    "Normal",
     "FourierEmbeddings",
     "IdentityPositionalEmbeddings",
     "LearnedPositionalEmbeddings",
@@ -219,6 +223,10 @@ NAME_MAP: dict[str, str] = {
    "TransformerBlock": "nn.attention",
    "TransformerCache": "nn.attention",
    "TransformerStack": "nn.attention",
+    "Categorical": "nn.distributions",
+    "Distribution": "nn.distributions",
+    "MixtureOfGaussians": "nn.distributions",
+    "Normal": "nn.distributions",
    "FourierEmbeddings": "nn.embeddings",
    "IdentityPositionalEmbeddings": "nn.embeddings",
    "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -405,6 +413,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         TransformerCache,
         TransformerStack,
     )
+    from xax.nn.distributions import Categorical, Distribution, MixtureOfGaussians, Normal
     from xax.nn.embeddings import (
         EmbeddingKind,
         FourierEmbeddings,
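
The four new names are registered in `__all__` and `NAME_MAP`, so they resolve from the package root like the other exports. A minimal usage sketch, assuming the lazy-import machinery resolves `xax.Normal` to `xax.nn.distributions.Normal` as `NAME_MAP` indicates (values are illustrative):

    import jax.numpy as jnp
    import jax.random as jr
    import xax

    # `xax.Normal` resolves through NAME_MAP to xax.nn.distributions.Normal.
    dist = xax.Normal(loc=jnp.zeros(3), scale=jnp.ones(3))
    sample = dist.sample(jr.PRNGKey(0))  # shape (3,)
    lp = dist.log_prob(sample)           # elementwise log-density, shape (3,)
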
xax/nn/distributions.py ADDED
@@ -0,0 +1,181 @@
+"""Defines some probability distribution helper functions.
+
+In general, it is preferable to use Distrax or another library, but we wanted
+to have a simple interface of our own so that we can quickly upgrade Jax
+versions (since Distrax is tied pretty closely to Tensorflow).
+"""
+
+__all__ = [
+    "Distribution",
+    "Categorical",
+    "Normal",
+    "MixtureOfGaussians",
+]
+
+from abc import ABC, abstractmethod
+
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+
+class Distribution(ABC):
+    @abstractmethod
+    def log_prob(self, x: Array) -> Array: ...
+
+    @abstractmethod
+    def sample(self, key: PRNGKeyArray) -> Array: ...
+
+    @abstractmethod
+    def mode(self) -> Array: ...
+
+    @abstractmethod
+    def entropy(self) -> Array: ...
+
+
+class Categorical(Distribution):
+    def __init__(self, logits_n: Array) -> None:
+        self.logits_n = logits_n
+
+    @property
+    def num_categories(self) -> int:
+        return self.logits_n.shape[-1]
+
+    def log_prob(self, x: Array) -> Array:
+        """Compute log probability for specific categories.
+
+        Args:
+            x: Array of category indices
+
+        Returns:
+            Log probabilities for the given categories
+        """
+        log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
+        # Use advanced indexing to get the log probabilities for the given categories
+        return log_probs[x]
+
+    def sample(self, key: PRNGKeyArray) -> Array:
+        return jax.random.categorical(key, self.logits_n, axis=-1)
+
+    def mode(self) -> Array:
+        return self.logits_n.argmax(axis=-1)
+
+    def entropy(self) -> Array:
+        """Compute entropy of the categorical distribution."""
+        probs = jax.nn.softmax(self.logits_n, axis=-1)
+        log_probs = jax.nn.log_softmax(self.logits_n, axis=-1)
+        return -jnp.sum(probs * log_probs, axis=-1)
+
+
+class Normal(Distribution):
+    def __init__(self, loc: Array, scale: Array) -> None:
+        self.loc = loc
+        self.scale = scale
+
+    def log_prob(self, x: Array) -> Array:
+        return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(self.scale) - (x - self.loc) ** 2 / (2 * self.scale**2)
+
+    def sample(self, key: PRNGKeyArray) -> Array:
+        return self.loc + self.scale * jax.random.normal(key, self.loc.shape)
+
+    def mode(self) -> Array:
+        return self.loc
+
+    def entropy(self) -> Array:
+        # Differential entropy of N(loc, scale^2) is 0.5 * log(2 * pi * e * scale^2).
+        return 0.5 * jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.scale)
+
+
+class MixtureOfGaussians(Distribution):
+    def __init__(self, means_nm: Array, stds_nm: Array, logits_nm: Array) -> None:
+        """Initialize a mixture of Gaussians.
+
+        Args:
+            means_nm: Array of shape (..., n_components) containing means
+            stds_nm: Array of shape (..., n_components) containing standard deviations
+            logits_nm: Array of shape (..., n_components) containing mixing logits
+        """
+        self.means_nm = means_nm
+        self.stds_nm = stds_nm
+        self.logits_nm = logits_nm
+
+    def log_prob(self, x: Array) -> Array:
+        """Compute log probability of the mixture.
+
+        Args:
+            x: Array of shape (...,) containing values to evaluate
+
+        Returns:
+            Log probabilities of shape (...,)
+        """
+        # Expand x to match component dimensions
+        x_expanded = x[..., None]  # Shape: (..., 1)
+
+        # Compute log probabilities for each component
+        component_log_probs = (
+            -0.5 * jnp.log(2 * jnp.pi)
+            - jnp.log(self.stds_nm)
+            - (x_expanded - self.means_nm) ** 2 / (2 * self.stds_nm**2)
+        )
+
+        # Compute mixing weights
+        mixing_logits = jax.nn.log_softmax(self.logits_nm, axis=-1)
+
+        # Combine using log-sum-exp trick for numerical stability
+        return jax.scipy.special.logsumexp(component_log_probs + mixing_logits, axis=-1)
+
+    def sample(self, key: PRNGKeyArray) -> Array:
+        """Sample from the mixture of Gaussians.
+
+        Args:
+            key: PRNG key
+
+        Returns:
+            Samples of shape (...,) where ... are the batch dimensions
+        """
+        # Sample component indices
+        component_key, sample_key = jax.random.split(key)
+        component_indices = jax.random.categorical(component_key, self.logits_nm, axis=-1)
+
+        # Sample from selected components using advanced indexing
+        # We need to handle the case where we have batch dimensions
+        batch_shape = self.means_nm.shape[:-1]  # All dimensions except the last (components)
+
+        # Reshape for easier indexing
+        means_flat = self.means_nm.reshape(-1, self.means_nm.shape[-1])
+        stds_flat = self.stds_nm.reshape(-1, self.stds_nm.shape[-1])
+        indices_flat = component_indices.reshape(-1)
+
+        # Get selected means and stds
+        selected_means = means_flat[jnp.arange(len(indices_flat)), indices_flat]
+        selected_stds = stds_flat[jnp.arange(len(indices_flat)), indices_flat]
+
+        # Generate random noise
+        noise = jax.random.normal(sample_key, selected_means.shape)
+
+        # Reshape back to original batch shape
+        samples = selected_means + selected_stds * noise
+        return samples.reshape(batch_shape)
+
+    def mode(self) -> Array:
+        """Return the mode of the mixture (approximate: the mean of the highest-weight component)."""
+        mixing_weights = jax.nn.softmax(self.logits_nm, axis=-1)
+        max_weight_idx = jnp.argmax(mixing_weights, axis=-1)
+
+        # Use advanced indexing to get the means of the highest weight components
+        batch_shape = self.means_nm.shape[:-1]
+        means_flat = self.means_nm.reshape(-1, self.means_nm.shape[-1])
+        indices_flat = max_weight_idx.reshape(-1)
+
+        selected_means = means_flat[jnp.arange(len(indices_flat)), indices_flat]
+        return selected_means.reshape(batch_shape)
+
+    def entropy(self) -> Array:
+        """Compute entropy of the mixture (approximate)."""
+        mixing_weights = jax.nn.softmax(self.logits_nm, axis=-1)
+        # Per-component Gaussian entropy is 0.5 * log(2 * pi * e * std^2).
+        component_entropies = 0.5 * jnp.log(2 * jnp.pi * jnp.e) + jnp.log(self.stds_nm)
+
+        # Weighted sum of component entropies plus mixing entropy
+        weighted_entropies = jnp.sum(mixing_weights * component_entropies, axis=-1)
+        mixing_entropy = -jnp.sum(mixing_weights * jnp.log(mixing_weights + 1e-8), axis=-1)
+
+        return weighted_entropies + mixing_entropy
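
A short usage sketch of the new module (values are illustrative; the mixture here has batch shape (2,) and three components per batch element):

    import jax
    import jax.numpy as jnp

    from xax.nn.distributions import Categorical, MixtureOfGaussians

    key = jax.random.PRNGKey(0)

    # Categorical over 4 classes; log_prob indexes into the last axis.
    cat = Categorical(jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4])))
    print(cat.mode())     # -> 3 (the highest-probability class)
    print(cat.entropy())  # scalar entropy in nats

    # Mixture with batch shape (2,) and 3 components per element.
    mog = MixtureOfGaussians(
        means_nm=jnp.zeros((2, 3)),
        stds_nm=jnp.ones((2, 3)),
        logits_nm=jnp.zeros((2, 3)),
    )
    samples = mog.sample(key)          # shape (2,)
    log_probs = mog.log_prob(samples)  # shape (2,)
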
xax/task/mixins/data_loader.py CHANGED
@@ -110,7 +110,12 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
             "or `get_data_iterator` to return an iterator for the given dataset."
         )
 
-    def get_dataloader(self, dataset: Dataset[T, Tc_co], phase: Phase) -> Dataloader[T, Tc_co]:
+    def get_dataloader(
+        self,
+        dataset: Dataset[T, Tc_co],
+        phase: Phase,
+        prefetch_factor: int | None = None,
+    ) -> Dataloader[T, Tc_co]:
         debugging = self.config.debug_dataloader
         if debugging:
             logger.warning("Parallel dataloaders disabled in debugging mode")
@@ -135,7 +140,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
             dataset=dataset,
             batch_size=self.config.batch_size,
             num_workers=0 if debugging else cfg.num_workers,
-            prefetch_factor=cfg.prefetch_factor,
+            prefetch_factor=cfg.prefetch_factor if prefetch_factor is None else prefetch_factor,
             mp_manager=self.multiprocessing_manager,
             dataloader_worker_init_fn=self.dataloader_worker_init_fn,
             collate_worker_init_fn=self.collate_worker_init_fn,
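
With the new keyword argument, callers can override the configured prefetch depth per dataloader, falling back to `cfg.prefetch_factor` when it is omitted. A sketch of the call pattern inside a `DataloadersMixin` subclass (the surrounding task class is assumed, not shown):

    # `train_ds` is whatever get_dataset returns for this task.
    train_ds = self.get_dataset("train")

    # Uses cfg.prefetch_factor, as before.
    default_dl = self.get_dataloader(train_ds, "train")

    # Explicit override, as train.py now does with updates_per_step + 1.
    deep_dl = self.get_dataloader(train_ds, "train", prefetch_factor=4)
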
xax/task/mixins/train.py CHANGED
@@ -60,7 +60,7 @@ from xax.utils.experiments import (
     get_state_file_string,
     get_training_code,
 )
-from xax.utils.jax import jit as xax_jit
+from xax.utils.jax import jit as xax_jit, scan as xax_scan
 from xax.utils.logging import LOG_PING, LOG_STATUS
 from xax.utils.pytree import get_pytree_param_count
 from xax.utils.text import highlight_exception_message, show_info
@@ -175,6 +175,7 @@ class TrainConfig(
     valid_first_n_seconds: float | None = field(60.0, help="Run first validation after N seconds")
     max_steps: int | None = field(None, help="Maximum number of steps to run")
     step_kind: str = field("step", help=f"How to measure a step; one of [{', '.join(get_args(StepKind))}]")
+    updates_per_step: int = field(1, help="Number of updates to perform per step")
     random_seed: int = field(1337, help="Random seed for the task")
     global_grad_clip: float = field(value=10.0, help="The maximum gradient norm to clip to.")
 
@@ -597,6 +598,7 @@ class TrainMixin(
         metrics = self.compute_metrics(model, batch, output, loss, state)
         return loss, (output, metrics)
 
+    @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
     def update(
         self,
         model_arr: PyTree,
@@ -609,44 +611,9 @@ class TrainMixin(
         grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
         grad_fn = xax_jit(static_argnums=[1], jit_level=3)(grad_fn)
         grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
-        model_arr, opt_state, grad_metrics = self.apply_gradients_with_clipping(model_arr, grads, optimizer, opt_state)
-        return model_arr, opt_state, output, metrics | grad_metrics
-
-    @xax_jit(static_argnames=["self", "optimizer"], jit_level=3)
-    def apply_gradients_with_clipping(
-        self,
-        model_arr: PyTree,
-        grads: PyTree,
-        optimizer: optax.GradientTransformation,
-        opt_state: optax.OptState,
-    ) -> tuple[PyTree, optax.OptState, dict[str, Array]]:
-        grad_norm = optax.global_norm(grads)
-        grad_metrics = {"grad_norm": grad_norm}
-
-        def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
-            # Clip gradients based on global norm, similar to optax.clip_by_global_norm
-            trigger = jnp.squeeze(grad_norm < self.config.global_grad_clip)
-
-            def clip_fn(t: Array) -> Array:
-                return jax.lax.select(trigger, t, (t / grad_norm.astype(t.dtype)) * self.config.global_grad_clip)
-
-            grads = jax.tree.map(clip_fn, grads)
-
-            # Apply the gradient updates.
-            updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
-            new_model_arr = eqx.apply_updates(model_arr, updates)
-            return new_model_arr, new_opt_state
-
-        # Don't apply updates if the gradient is NaN or Inf.
-        new_model_arr, new_opt_state = jax.lax.cond(
-            jnp.isnan(grad_norm) | jnp.isinf(grad_norm),
-            lambda *_: (model_arr, opt_state),
-            apply,
-            grads,
-            grad_norm,
-        )
-
-        return new_model_arr, new_opt_state, grad_metrics
+        updates, opt_state = optimizer.update(grads, opt_state, model_arr)
+        model_arr = eqx.apply_updates(model_arr, updates)
+        return model_arr, opt_state, output, metrics
 
     def get_size_of_batch(self, batch: Batch) -> int | None:
         """Gets the batch size for the current batch.
@@ -729,11 +696,36 @@ class TrainMixin(
         model_static: PyTree,
         optimizer: optax.GradientTransformation,
         opt_state: optax.OptState,
-        batch: Batch,
+        batches: Batch,
         state: State,
     ) -> tuple[PyTree, optax.OptState, Output, FrozenDict[str, Array]]:
-        model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
-        return model_arr, opt_state, output, FrozenDict(metrics)
+        def update_fn(
+            carry: tuple[PyTree, optax.OptState],
+            batch: Batch,
+        ) -> tuple[tuple[PyTree, optax.OptState], tuple[Output, FrozenDict[str, Array]]]:
+            model_arr, opt_state = carry
+            model_arr, opt_state, output, metrics = self.update(
+                model_arr,
+                model_static,
+                optimizer,
+                opt_state,
+                batch,
+                state,
+            )
+            return (model_arr, opt_state), (output, FrozenDict(metrics))
+
+        (model_arr, opt_state), (output, metrics) = xax_scan(
+            update_fn,
+            (model_arr, opt_state),
+            batches,
+            jit_level=3,
+        )
+
+        # Only get the final output and metrics.
+        output = jax.tree.map(lambda x: x[-1], output)
+        metrics = jax.tree.map(lambda x: x[-1], metrics)
+
+        return model_arr, opt_state, output, metrics
 
     @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
     def val_step(
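
The carry/stack pattern that `train_step` now relies on is that of `jax.lax.scan`; `xax_scan` is assumed here to be xax's wrapper around it (the `jit_level` argument is xax-specific). A minimal standalone sketch of the same shape of computation, with illustrative names:

    import jax
    import jax.numpy as jnp

    def update_fn(carry, batch):
        # carry plays the role of (model_arr, opt_state); batch is one slice
        # of the stacked batches along the leading axis.
        total = carry + batch.sum()
        metrics = {"batch_mean": batch.mean()}
        return total, metrics

    batches = jnp.arange(12.0).reshape(3, 4)  # leading axis = updates per step
    final_carry, stacked = jax.lax.scan(update_fn, jnp.float32(0.0), batches)

    # Like train_step, keep only the metrics from the final update.
    last_metrics = jax.tree.map(lambda x: x[-1], stacked)
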
@@ -775,40 +767,36 @@ class TrainMixin(
                 output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
                 self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
 
-            state = state.replace(
-                num_steps=state.num_steps + 1,
-                num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
-            )
-
             state = state.replace(
+                num_steps=state.num_steps + 1,
+                num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
                 elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
             )
 
         with ContextTimer() as timer:
             state = self.on_step_start(state)
             state = state.replace(phase="train")
-            train_batch = next(train_pf)
+            train_batches = list(itertools.islice(train_pf, self.config.updates_per_step))
             model_arr, opt_state, output, metrics = self.train_step(
                 model_arr=model_arr,
                 model_static=model_static,
                 optimizer=optimizer,
                 opt_state=opt_state,
-                batch=train_batch,
+                batches=jax.tree.map(lambda *xs: jnp.stack(xs, axis=0), *train_batches),
                 state=state,
             )
-            self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
-
-            state = state.replace(
-                num_steps=state.num_steps + 1,
-                num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
-            )
-
+            self.log_step(eqx.combine(model_arr, model_static), train_batches[-1], output, metrics, state)
             state = self.on_step_end(state)
 
             state = state.replace(
+                num_steps=state.num_steps + 1,
+                num_samples=state.num_samples + (self.get_size_of_batch(train_batches[-1]) or 0),
                 elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
             )
 
+        if state.num_steps <= 3:
+            logger.log(LOG_PING, "Step %d took %.2f seconds", state.num_steps, timer.elapsed_time)
+
         if self.should_checkpoint(state):
             model = eqx.combine(model_arr, model_static)
             self.save_checkpoint(models=[model], optimizers=[optimizer], opt_states=[opt_state], state=state)
@@ -827,7 +815,7 @@ class TrainMixin(
                 pass
 
             train_ds = self.get_dataset("train")
-            train_dl = self.get_dataloader(train_ds, "train")
+            train_dl = self.get_dataloader(train_ds, "train", prefetch_factor=self.config.updates_per_step + 1)
             train_pf = self.get_prefetcher(train_dl)
 
             try:
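
Taken together, `updates_per_step` lets one outer step consume several batches: the loop pulls that many batches from the prefetcher, stacks them along a new leading axis, scans `update` over the stack, and the train dataloader prefetches `updates_per_step + 1` batches to stay ahead. A sketch of opting in, assuming `TrainConfig` behaves as a plain dataclass whose defaults can be overridden in a subclass (the subclass itself is hypothetical):

    from dataclasses import dataclass

    from xax.task.mixins.train import TrainConfig

    @dataclass
    class MyConfig(TrainConfig):
        # Perform 4 optimizer updates per outer step; the train dataloader
        # will then prefetch updates_per_step + 1 = 5 batches.
        updates_per_step: int = 4
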
xax-0.3.7.dist-info/METADATA → xax-0.3.9.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.3.7
+Version: 0.3.9
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
xax-0.3.7.dist-info/RECORD → xax-0.3.9.dist-info/RECORD RENAMED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=YCDjLRwliJCyEmNFC56PNQXV9Vn9Fr13VJS_am4h3To,16336
+xax/__init__.py,sha256=PLNyZ5fOm0f2JduTMauNH2jqxN4g0GAZ9MSNzEsAQS4,16665
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -9,6 +9,7 @@ xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
 xax/core/state.py,sha256=_gtINsRc310Bu_HuIYsDoOKTZa6DgU2tz0IOKkdnY9Q,3813
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/nn/attention.py,sha256=m6yEoRqf7-wLgrEltaR6CxF_Cody0MaNtAkuKk39qJI,31176
+xax/nn/distributions.py,sha256=096IDvoJ0ZA4SqcfgNSmrICsGcsKVcTAh0Vl6SwN3-o,6343
 xax/nn/embeddings.py,sha256=8tAuAPdkVj-U5IwtRZKHA0WYMFRbpCuwyAxcChdKhbE,11784
 xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
 xax/nn/geom.py,sha256=c9K52vLm-V-15CRqMNx0OmqsWfb3PHQxXW4OSx9kCAk,10635
@@ -36,13 +37,13 @@ xax/task/mixins/artifacts.py,sha256=R-y3p7__zJHlHDqwDVAZysg2ZmebCJbqAx_xGT2Xpd0,
 xax/task/mixins/checkpointing.py,sha256=v50IZ7j58DWmEu-_6Zh_02R5KUVGhrMkg5n-MYM_J4c,11484
 xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
 xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
-xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
+xax/task/mixins/data_loader.py,sha256=BKfOVWXR70vbyHMFlnlUiQQHXHH5zTj5WtmsymNCFB4,6722
 xax/task/mixins/gpu_stats.py,sha256=USOyhXldxbsrl6eCtoFKTWUm_lfeG0cUCkQNUpXRdtA,8880
 xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=bjBoigTCjbq9H4hcqIO32irHBc9rC2zkgXrnGNI2RtI,33266
+xax/task/mixins/train.py,sha256=_kDpifLi1arSuT0ssFhBV0axpvLlQG3a97pohya0Eqc,32908
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
 xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
@@ -59,9 +60,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.3.7.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
-xax-0.3.7.dist-info/METADATA,sha256=8Zb0pvTJOjrCHK7giM2MbhlGCPREQewJK3GgRDQNWY0,1246
-xax-0.3.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-xax-0.3.7.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
-xax-0.3.7.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
-xax-0.3.7.dist-info/RECORD,,
+xax-0.3.9.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.3.9.dist-info/METADATA,sha256=UKUqiAADutUzjtq4WpChJJzWDcfKBsrbWMmZu0LUqQk,1246
+xax-0.3.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+xax-0.3.9.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+xax-0.3.9.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.3.9.dist-info/RECORD,,