xax 0.3.3__tar.gz → 0.3.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.3.3/xax.egg-info → xax-0.3.9}/PKG-INFO +1 -1
- {xax-0.3.3 → xax-0.3.9}/xax/__init__.py +35 -8
- xax-0.3.9/xax/nn/attention.py +940 -0
- xax-0.3.9/xax/nn/distributions.py +181 -0
- {xax-0.3.3 → xax-0.3.9}/xax/nn/embeddings.py +10 -10
- {xax-0.3.3 → xax-0.3.9}/xax/nn/geom.py +5 -5
- {xax-0.3.3 → xax-0.3.9}/xax/nn/ssm.py +6 -6
- {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/data_loader.py +7 -2
- {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/train.py +51 -58
- {xax-0.3.3 → xax-0.3.9}/xax/utils/pytree.py +13 -0
- {xax-0.3.3 → xax-0.3.9/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.3.3 → xax-0.3.9}/xax.egg-info/SOURCES.txt +1 -0
- xax-0.3.3/xax/nn/attention.py +0 -738
- {xax-0.3.3 → xax-0.3.9}/LICENSE +0 -0
- {xax-0.3.3 → xax-0.3.9}/MANIFEST.in +0 -0
- {xax-0.3.3 → xax-0.3.9}/README.md +0 -0
- {xax-0.3.3 → xax-0.3.9}/pyproject.toml +0 -0
- {xax-0.3.3 → xax-0.3.9}/setup.cfg +0 -0
- {xax-0.3.3 → xax-0.3.9}/setup.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/cli/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/cli/edit_config.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/core/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/core/conf.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/core/state.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/nn/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/nn/functions.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/nn/losses.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/nn/metrics.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/nn/parallel.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/py.typed +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/requirements-dev.txt +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/requirements.txt +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/base.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/launchers/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/launchers/base.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/launchers/cli.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/launchers/single_process.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/logger.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/loggers/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/loggers/callback.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/loggers/json.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/loggers/state.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/loggers/stdout.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/compile.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/logger.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/process.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/runnable.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/script.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/task/task.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/data/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/data/collate.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/debugging.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/experiments.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/jax.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/jaxpr.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/logging.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/numpy.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/profile.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/tensorboard.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/text.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/types/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax.egg-info/requires.txt +0 -0
- {xax-0.3.3 → xax-0.3.9}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.3.3"
|
15
|
+
__version__ = "0.3.9"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -23,10 +23,18 @@ __all__ = [
|
|
23
23
|
"get_run_dir",
|
24
24
|
"load_user_config",
|
25
25
|
"State",
|
26
|
+
"AttentionCache",
|
27
|
+
"AttentionCacheDict",
|
26
28
|
"CrossAttentionBlock",
|
27
29
|
"SelfAttentionBlock",
|
28
30
|
"Transformer",
|
29
31
|
"TransformerBlock",
|
32
|
+
"TransformerCache",
|
33
|
+
"TransformerStack",
|
34
|
+
"Categorical",
|
35
|
+
"Distribution",
|
36
|
+
"MixtureOfGaussians",
|
37
|
+
"Normal",
|
30
38
|
"FourierEmbeddings",
|
31
39
|
"IdentityPositionalEmbeddings",
|
32
40
|
"LearnedPositionalEmbeddings",
|
@@ -132,6 +140,7 @@ __all__ = [
|
|
132
140
|
"compute_nan_ratio",
|
133
141
|
"flatten_array",
|
134
142
|
"flatten_pytree",
|
143
|
+
"get_pytree_mapping",
|
135
144
|
"get_pytree_param_count",
|
136
145
|
"pytree_has_nans",
|
137
146
|
"reshuffle_pytree",
|
@@ -206,10 +215,18 @@ NAME_MAP: dict[str, str] = {
|
|
206
215
|
"get_run_dir": "core.conf",
|
207
216
|
"load_user_config": "core.conf",
|
208
217
|
"State": "core.state",
|
218
|
+
"AttentionCache": "nn.attention",
|
219
|
+
"AttentionCacheDict": "nn.attention",
|
209
220
|
"CrossAttentionBlock": "nn.attention",
|
210
221
|
"SelfAttentionBlock": "nn.attention",
|
211
222
|
"Transformer": "nn.attention",
|
212
223
|
"TransformerBlock": "nn.attention",
|
224
|
+
"TransformerCache": "nn.attention",
|
225
|
+
"TransformerStack": "nn.attention",
|
226
|
+
"Categorical": "nn.distributions",
|
227
|
+
"Distribution": "nn.distributions",
|
228
|
+
"MixtureOfGaussians": "nn.distributions",
|
229
|
+
"Normal": "nn.distributions",
|
213
230
|
"FourierEmbeddings": "nn.embeddings",
|
214
231
|
"IdentityPositionalEmbeddings": "nn.embeddings",
|
215
232
|
"LearnedPositionalEmbeddings": "nn.embeddings",
|
@@ -315,6 +332,7 @@ NAME_MAP: dict[str, str] = {
|
|
315
332
|
"compute_nan_ratio": "utils.pytree",
|
316
333
|
"flatten_array": "utils.pytree",
|
317
334
|
"flatten_pytree": "utils.pytree",
|
335
|
+
"get_pytree_mapping": "utils.pytree",
|
318
336
|
"get_pytree_param_count": "utils.pytree",
|
319
337
|
"pytree_has_nans": "utils.pytree",
|
320
338
|
"reshuffle_pytree": "utils.pytree",
|
@@ -362,6 +380,9 @@ NAME_MAP.update(
|
|
362
380
|
},
|
363
381
|
)
|
364
382
|
|
383
|
+
# In NAME_MAP
|
384
|
+
NAME_MAP["TransformerStack"] = "nn.attention"
|
385
|
+
|
365
386
|
|
366
387
|
def __getattr__(name: str) -> object:
|
367
388
|
if name not in NAME_MAP:
|
@@ -382,7 +403,17 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
382
403
|
load_user_config,
|
383
404
|
)
|
384
405
|
from xax.core.state import Phase, State
|
385
|
-
from xax.nn.attention import CrossAttentionBlock, SelfAttentionBlock, Transformer, TransformerBlock
|
406
|
+
from xax.nn.attention import (
|
407
|
+
AttentionCache,
|
408
|
+
AttentionCacheDict,
|
409
|
+
CrossAttentionBlock,
|
410
|
+
SelfAttentionBlock,
|
411
|
+
Transformer,
|
412
|
+
TransformerBlock,
|
413
|
+
TransformerCache,
|
414
|
+
TransformerStack,
|
415
|
+
)
|
416
|
+
from xax.nn.distributions import Categorical, Distribution, MixtureOfGaussians, Normal
|
386
417
|
from xax.nn.embeddings import (
|
387
418
|
EmbeddingKind,
|
388
419
|
FourierEmbeddings,
|
@@ -411,12 +442,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
411
442
|
rotation_matrix_to_rotation6d,
|
412
443
|
)
|
413
444
|
from xax.nn.losses import cross_entropy
|
414
|
-
from xax.nn.metrics import (
|
415
|
-
NormType,
|
416
|
-
cast_norm_type,
|
417
|
-
dynamic_time_warping,
|
418
|
-
get_norm,
|
419
|
-
)
|
445
|
+
from xax.nn.metrics import NormType, cast_norm_type, dynamic_time_warping, get_norm
|
420
446
|
from xax.nn.parallel import is_master
|
421
447
|
from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
|
422
448
|
from xax.task.base import RawConfigType
|
@@ -494,6 +520,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
494
520
|
compute_nan_ratio,
|
495
521
|
flatten_array,
|
496
522
|
flatten_pytree,
|
523
|
+
get_pytree_mapping,
|
497
524
|
get_pytree_param_count,
|
498
525
|
pytree_has_nans,
|
499
526
|
reshuffle_pytree,
|