xax 0.3.2__tar.gz → 0.3.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.3.2/xax.egg-info → xax-0.3.4}/PKG-INFO +1 -1
- {xax-0.3.2 → xax-0.3.4}/xax/__init__.py +28 -10
- xax-0.3.4/xax/nn/attention.py +888 -0
- {xax-0.3.2 → xax-0.3.4}/xax/nn/geom.py +82 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/mixins/runnable.py +1 -2
- {xax-0.3.2 → xax-0.3.4/xax.egg-info}/PKG-INFO +1 -1
- xax-0.3.2/xax/nn/attention.py +0 -738
- {xax-0.3.2 → xax-0.3.4}/LICENSE +0 -0
- {xax-0.3.2 → xax-0.3.4}/MANIFEST.in +0 -0
- {xax-0.3.2 → xax-0.3.4}/README.md +0 -0
- {xax-0.3.2 → xax-0.3.4}/pyproject.toml +0 -0
- {xax-0.3.2 → xax-0.3.4}/setup.cfg +0 -0
- {xax-0.3.2 → xax-0.3.4}/setup.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/cli/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/cli/edit_config.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/core/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/core/conf.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/core/state.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/nn/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/nn/embeddings.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/nn/functions.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/nn/losses.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/nn/metrics.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/nn/parallel.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/nn/ssm.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/py.typed +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/requirements-dev.txt +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/requirements.txt +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/base.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/launchers/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/launchers/base.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/launchers/cli.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/launchers/single_process.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/logger.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/loggers/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/loggers/callback.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/loggers/json.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/loggers/state.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/loggers/stdout.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/mixins/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/mixins/compile.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/mixins/logger.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/mixins/process.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/mixins/train.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/script.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/task/task.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/data/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/data/collate.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/debugging.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/experiments.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/jax.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/jaxpr.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/logging.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/numpy.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/profile.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/pytree.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/tensorboard.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/text.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/types/__init__.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax.egg-info/requires.txt +0 -0
- {xax-0.3.2 → xax-0.3.4}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.3.2"
|
15
|
+
__version__ = "0.3.4"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -23,10 +23,14 @@ __all__ = [
|
|
23
23
|
"get_run_dir",
|
24
24
|
"load_user_config",
|
25
25
|
"State",
|
26
|
+
"AttentionCache",
|
27
|
+
"AttentionCacheDict",
|
26
28
|
"CrossAttentionBlock",
|
27
29
|
"SelfAttentionBlock",
|
28
30
|
"Transformer",
|
29
31
|
"TransformerBlock",
|
32
|
+
"TransformerCache",
|
33
|
+
"TransformerStack",
|
30
34
|
"FourierEmbeddings",
|
31
35
|
"IdentityPositionalEmbeddings",
|
32
36
|
"LearnedPositionalEmbeddings",
|
@@ -42,11 +46,12 @@ __all__ = [
|
|
42
46
|
"euler_to_quat",
|
43
47
|
"get_projected_gravity_vector_from_quat",
|
44
48
|
"normalize",
|
49
|
+
"quat_mul",
|
45
50
|
"quat_to_euler",
|
46
51
|
"quat_to_rotmat",
|
47
|
-
"quat_mul",
|
48
52
|
"rotate_vector_by_quat",
|
49
53
|
"rotation6d_to_rotation_matrix",
|
54
|
+
"rotation_matrix_to_quat",
|
50
55
|
"rotation_matrix_to_rotation6d",
|
51
56
|
"cross_entropy",
|
52
57
|
"cast_norm_type",
|
@@ -205,10 +210,14 @@ NAME_MAP: dict[str, str] = {
|
|
205
210
|
"get_run_dir": "core.conf",
|
206
211
|
"load_user_config": "core.conf",
|
207
212
|
"State": "core.state",
|
213
|
+
"AttentionCache": "nn.attention",
|
214
|
+
"AttentionCacheDict": "nn.attention",
|
208
215
|
"CrossAttentionBlock": "nn.attention",
|
209
216
|
"SelfAttentionBlock": "nn.attention",
|
210
217
|
"Transformer": "nn.attention",
|
211
218
|
"TransformerBlock": "nn.attention",
|
219
|
+
"TransformerCache": "nn.attention",
|
220
|
+
"TransformerStack": "nn.attention",
|
212
221
|
"FourierEmbeddings": "nn.embeddings",
|
213
222
|
"IdentityPositionalEmbeddings": "nn.embeddings",
|
214
223
|
"LearnedPositionalEmbeddings": "nn.embeddings",
|
@@ -224,11 +233,12 @@ NAME_MAP: dict[str, str] = {
|
|
224
233
|
"euler_to_quat": "nn.geom",
|
225
234
|
"get_projected_gravity_vector_from_quat": "nn.geom",
|
226
235
|
"normalize": "nn.geom",
|
236
|
+
"quat_mul": "nn.geom",
|
227
237
|
"quat_to_euler": "nn.geom",
|
228
238
|
"quat_to_rotmat": "nn.geom",
|
229
|
-
"quat_mul": "nn.geom",
|
230
239
|
"rotate_vector_by_quat": "nn.geom",
|
231
240
|
"rotation6d_to_rotation_matrix": "nn.geom",
|
241
|
+
"rotation_matrix_to_quat": "nn.geom",
|
232
242
|
"rotation_matrix_to_rotation6d": "nn.geom",
|
233
243
|
"cross_entropy": "nn.losses",
|
234
244
|
"cast_norm_type": "nn.metrics",
|
@@ -360,6 +370,9 @@ NAME_MAP.update(
|
|
360
370
|
},
|
361
371
|
)
|
362
372
|
|
373
|
+
# In NAME_MAP
|
374
|
+
NAME_MAP["TransformerStack"] = "nn.attention"
|
375
|
+
|
363
376
|
|
364
377
|
def __getattr__(name: str) -> object:
|
365
378
|
if name not in NAME_MAP:
|
@@ -380,7 +393,16 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
380
393
|
load_user_config,
|
381
394
|
)
|
382
395
|
from xax.core.state import Phase, State
|
383
|
-
from xax.nn.attention import
|
396
|
+
from xax.nn.attention import (
|
397
|
+
AttentionCache,
|
398
|
+
AttentionCacheDict,
|
399
|
+
CrossAttentionBlock,
|
400
|
+
SelfAttentionBlock,
|
401
|
+
Transformer,
|
402
|
+
TransformerBlock,
|
403
|
+
TransformerCache,
|
404
|
+
TransformerStack,
|
405
|
+
)
|
384
406
|
from xax.nn.embeddings import (
|
385
407
|
EmbeddingKind,
|
386
408
|
FourierEmbeddings,
|
@@ -405,15 +427,11 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
405
427
|
quat_to_rotmat,
|
406
428
|
rotate_vector_by_quat,
|
407
429
|
rotation6d_to_rotation_matrix,
|
430
|
+
rotation_matrix_to_quat,
|
408
431
|
rotation_matrix_to_rotation6d,
|
409
432
|
)
|
410
433
|
from xax.nn.losses import cross_entropy
|
411
|
-
from xax.nn.metrics import (
|
412
|
-
NormType,
|
413
|
-
cast_norm_type,
|
414
|
-
dynamic_time_warping,
|
415
|
-
get_norm,
|
416
|
-
)
|
434
|
+
from xax.nn.metrics import NormType, cast_norm_type, dynamic_time_warping, get_norm
|
417
435
|
from xax.nn.parallel import is_master
|
418
436
|
from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
|
419
437
|
from xax.task.base import RawConfigType
|