xax 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.2"
15
+ __version__ = "0.3.4"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -23,10 +23,14 @@ __all__ = [
23
23
  "get_run_dir",
24
24
  "load_user_config",
25
25
  "State",
26
+ "AttentionCache",
27
+ "AttentionCacheDict",
26
28
  "CrossAttentionBlock",
27
29
  "SelfAttentionBlock",
28
30
  "Transformer",
29
31
  "TransformerBlock",
32
+ "TransformerCache",
33
+ "TransformerStack",
30
34
  "FourierEmbeddings",
31
35
  "IdentityPositionalEmbeddings",
32
36
  "LearnedPositionalEmbeddings",
@@ -42,11 +46,12 @@ __all__ = [
42
46
  "euler_to_quat",
43
47
  "get_projected_gravity_vector_from_quat",
44
48
  "normalize",
49
+ "quat_mul",
45
50
  "quat_to_euler",
46
51
  "quat_to_rotmat",
47
- "quat_mul",
48
52
  "rotate_vector_by_quat",
49
53
  "rotation6d_to_rotation_matrix",
54
+ "rotation_matrix_to_quat",
50
55
  "rotation_matrix_to_rotation6d",
51
56
  "cross_entropy",
52
57
  "cast_norm_type",
@@ -205,10 +210,14 @@ NAME_MAP: dict[str, str] = {
205
210
  "get_run_dir": "core.conf",
206
211
  "load_user_config": "core.conf",
207
212
  "State": "core.state",
213
+ "AttentionCache": "nn.attention",
214
+ "AttentionCacheDict": "nn.attention",
208
215
  "CrossAttentionBlock": "nn.attention",
209
216
  "SelfAttentionBlock": "nn.attention",
210
217
  "Transformer": "nn.attention",
211
218
  "TransformerBlock": "nn.attention",
219
+ "TransformerCache": "nn.attention",
220
+ "TransformerStack": "nn.attention",
212
221
  "FourierEmbeddings": "nn.embeddings",
213
222
  "IdentityPositionalEmbeddings": "nn.embeddings",
214
223
  "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -224,11 +233,12 @@ NAME_MAP: dict[str, str] = {
224
233
  "euler_to_quat": "nn.geom",
225
234
  "get_projected_gravity_vector_from_quat": "nn.geom",
226
235
  "normalize": "nn.geom",
236
+ "quat_mul": "nn.geom",
227
237
  "quat_to_euler": "nn.geom",
228
238
  "quat_to_rotmat": "nn.geom",
229
- "quat_mul": "nn.geom",
230
239
  "rotate_vector_by_quat": "nn.geom",
231
240
  "rotation6d_to_rotation_matrix": "nn.geom",
241
+ "rotation_matrix_to_quat": "nn.geom",
232
242
  "rotation_matrix_to_rotation6d": "nn.geom",
233
243
  "cross_entropy": "nn.losses",
234
244
  "cast_norm_type": "nn.metrics",
@@ -360,6 +370,9 @@ NAME_MAP.update(
360
370
  },
361
371
  )
362
372
 
373
+ # In NAME_MAP
374
+ NAME_MAP["TransformerStack"] = "nn.attention"
375
+
363
376
 
364
377
  def __getattr__(name: str) -> object:
365
378
  if name not in NAME_MAP:
@@ -380,7 +393,16 @@ if IMPORT_ALL or TYPE_CHECKING:
380
393
  load_user_config,
381
394
  )
382
395
  from xax.core.state import Phase, State
383
- from xax.nn.attention import CrossAttentionBlock, SelfAttentionBlock, Transformer, TransformerBlock
396
+ from xax.nn.attention import (
397
+ AttentionCache,
398
+ AttentionCacheDict,
399
+ CrossAttentionBlock,
400
+ SelfAttentionBlock,
401
+ Transformer,
402
+ TransformerBlock,
403
+ TransformerCache,
404
+ TransformerStack,
405
+ )
384
406
  from xax.nn.embeddings import (
385
407
  EmbeddingKind,
386
408
  FourierEmbeddings,
@@ -405,15 +427,11 @@ if IMPORT_ALL or TYPE_CHECKING:
405
427
  quat_to_rotmat,
406
428
  rotate_vector_by_quat,
407
429
  rotation6d_to_rotation_matrix,
430
+ rotation_matrix_to_quat,
408
431
  rotation_matrix_to_rotation6d,
409
432
  )
410
433
  from xax.nn.losses import cross_entropy
411
- from xax.nn.metrics import (
412
- NormType,
413
- cast_norm_type,
414
- dynamic_time_warping,
415
- get_norm,
416
- )
434
+ from xax.nn.metrics import NormType, cast_norm_type, dynamic_time_warping, get_norm
417
435
  from xax.nn.parallel import is_master
418
436
  from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
419
437
  from xax.task.base import RawConfigType