xax 0.3.3__tar.gz → 0.3.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.3.3/xax.egg-info → xax-0.3.5}/PKG-INFO +1 -1
- {xax-0.3.3 → xax-0.3.5}/xax/__init__.py +23 -8
- xax-0.3.5/xax/nn/attention.py +849 -0
- {xax-0.3.3 → xax-0.3.5}/xax/nn/embeddings.py +10 -10
- {xax-0.3.3 → xax-0.3.5}/xax/nn/geom.py +5 -5
- {xax-0.3.3 → xax-0.3.5}/xax/nn/ssm.py +6 -6
- {xax-0.3.3 → xax-0.3.5}/xax/task/mixins/train.py +6 -1
- {xax-0.3.3 → xax-0.3.5/xax.egg-info}/PKG-INFO +1 -1
- xax-0.3.3/xax/nn/attention.py +0 -738
- {xax-0.3.3 → xax-0.3.5}/LICENSE +0 -0
- {xax-0.3.3 → xax-0.3.5}/MANIFEST.in +0 -0
- {xax-0.3.3 → xax-0.3.5}/README.md +0 -0
- {xax-0.3.3 → xax-0.3.5}/pyproject.toml +0 -0
- {xax-0.3.3 → xax-0.3.5}/setup.cfg +0 -0
- {xax-0.3.3 → xax-0.3.5}/setup.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/cli/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/cli/edit_config.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/core/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/core/conf.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/core/state.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/nn/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/nn/functions.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/nn/losses.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/nn/metrics.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/nn/parallel.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/py.typed +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/requirements-dev.txt +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/requirements.txt +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/base.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/launchers/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/launchers/base.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/launchers/cli.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/launchers/single_process.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/logger.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/loggers/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/loggers/callback.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/loggers/json.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/loggers/state.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/loggers/stdout.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/mixins/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/mixins/compile.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/mixins/logger.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/mixins/process.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/mixins/runnable.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/script.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/task/task.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/data/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/data/collate.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/debugging.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/experiments.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/jax.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/jaxpr.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/logging.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/numpy.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/profile.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/pytree.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/tensorboard.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/text.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/types/__init__.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax.egg-info/requires.txt +0 -0
- {xax-0.3.3 → xax-0.3.5}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.3.
|
15
|
+
__version__ = "0.3.5"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -23,10 +23,14 @@ __all__ = [
|
|
23
23
|
"get_run_dir",
|
24
24
|
"load_user_config",
|
25
25
|
"State",
|
26
|
+
"AttentionCache",
|
27
|
+
"AttentionCacheDict",
|
26
28
|
"CrossAttentionBlock",
|
27
29
|
"SelfAttentionBlock",
|
28
30
|
"Transformer",
|
29
31
|
"TransformerBlock",
|
32
|
+
"TransformerCache",
|
33
|
+
"TransformerStack",
|
30
34
|
"FourierEmbeddings",
|
31
35
|
"IdentityPositionalEmbeddings",
|
32
36
|
"LearnedPositionalEmbeddings",
|
@@ -206,10 +210,14 @@ NAME_MAP: dict[str, str] = {
|
|
206
210
|
"get_run_dir": "core.conf",
|
207
211
|
"load_user_config": "core.conf",
|
208
212
|
"State": "core.state",
|
213
|
+
"AttentionCache": "nn.attention",
|
214
|
+
"AttentionCacheDict": "nn.attention",
|
209
215
|
"CrossAttentionBlock": "nn.attention",
|
210
216
|
"SelfAttentionBlock": "nn.attention",
|
211
217
|
"Transformer": "nn.attention",
|
212
218
|
"TransformerBlock": "nn.attention",
|
219
|
+
"TransformerCache": "nn.attention",
|
220
|
+
"TransformerStack": "nn.attention",
|
213
221
|
"FourierEmbeddings": "nn.embeddings",
|
214
222
|
"IdentityPositionalEmbeddings": "nn.embeddings",
|
215
223
|
"LearnedPositionalEmbeddings": "nn.embeddings",
|
@@ -362,6 +370,9 @@ NAME_MAP.update(
|
|
362
370
|
},
|
363
371
|
)
|
364
372
|
|
373
|
+
# In NAME_MAP
|
374
|
+
NAME_MAP["TransformerStack"] = "nn.attention"
|
375
|
+
|
365
376
|
|
366
377
|
def __getattr__(name: str) -> object:
|
367
378
|
if name not in NAME_MAP:
|
@@ -382,7 +393,16 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
382
393
|
load_user_config,
|
383
394
|
)
|
384
395
|
from xax.core.state import Phase, State
|
385
|
-
from xax.nn.attention import
|
396
|
+
from xax.nn.attention import (
|
397
|
+
AttentionCache,
|
398
|
+
AttentionCacheDict,
|
399
|
+
CrossAttentionBlock,
|
400
|
+
SelfAttentionBlock,
|
401
|
+
Transformer,
|
402
|
+
TransformerBlock,
|
403
|
+
TransformerCache,
|
404
|
+
TransformerStack,
|
405
|
+
)
|
386
406
|
from xax.nn.embeddings import (
|
387
407
|
EmbeddingKind,
|
388
408
|
FourierEmbeddings,
|
@@ -411,12 +431,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
411
431
|
rotation_matrix_to_rotation6d,
|
412
432
|
)
|
413
433
|
from xax.nn.losses import cross_entropy
|
414
|
-
from xax.nn.metrics import
|
415
|
-
NormType,
|
416
|
-
cast_norm_type,
|
417
|
-
dynamic_time_warping,
|
418
|
-
get_norm,
|
419
|
-
)
|
434
|
+
from xax.nn.metrics import NormType, cast_norm_type, dynamic_time_warping, get_norm
|
420
435
|
from xax.nn.parallel import is_master
|
421
436
|
from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
|
422
437
|
from xax.task.base import RawConfigType
|