xax 0.3.3__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75):
  1. {xax-0.3.3/xax.egg-info → xax-0.3.4}/PKG-INFO +1 -1
  2. {xax-0.3.3 → xax-0.3.4}/xax/__init__.py +23 -8
  3. xax-0.3.4/xax/nn/attention.py +888 -0
  4. {xax-0.3.3 → xax-0.3.4/xax.egg-info}/PKG-INFO +1 -1
  5. xax-0.3.3/xax/nn/attention.py +0 -738
  6. {xax-0.3.3 → xax-0.3.4}/LICENSE +0 -0
  7. {xax-0.3.3 → xax-0.3.4}/MANIFEST.in +0 -0
  8. {xax-0.3.3 → xax-0.3.4}/README.md +0 -0
  9. {xax-0.3.3 → xax-0.3.4}/pyproject.toml +0 -0
  10. {xax-0.3.3 → xax-0.3.4}/setup.cfg +0 -0
  11. {xax-0.3.3 → xax-0.3.4}/setup.py +0 -0
  12. {xax-0.3.3 → xax-0.3.4}/xax/cli/__init__.py +0 -0
  13. {xax-0.3.3 → xax-0.3.4}/xax/cli/edit_config.py +0 -0
  14. {xax-0.3.3 → xax-0.3.4}/xax/core/__init__.py +0 -0
  15. {xax-0.3.3 → xax-0.3.4}/xax/core/conf.py +0 -0
  16. {xax-0.3.3 → xax-0.3.4}/xax/core/state.py +0 -0
  17. {xax-0.3.3 → xax-0.3.4}/xax/nn/__init__.py +0 -0
  18. {xax-0.3.3 → xax-0.3.4}/xax/nn/embeddings.py +0 -0
  19. {xax-0.3.3 → xax-0.3.4}/xax/nn/functions.py +0 -0
  20. {xax-0.3.3 → xax-0.3.4}/xax/nn/geom.py +0 -0
  21. {xax-0.3.3 → xax-0.3.4}/xax/nn/losses.py +0 -0
  22. {xax-0.3.3 → xax-0.3.4}/xax/nn/metrics.py +0 -0
  23. {xax-0.3.3 → xax-0.3.4}/xax/nn/parallel.py +0 -0
  24. {xax-0.3.3 → xax-0.3.4}/xax/nn/ssm.py +0 -0
  25. {xax-0.3.3 → xax-0.3.4}/xax/py.typed +0 -0
  26. {xax-0.3.3 → xax-0.3.4}/xax/requirements-dev.txt +0 -0
  27. {xax-0.3.3 → xax-0.3.4}/xax/requirements.txt +0 -0
  28. {xax-0.3.3 → xax-0.3.4}/xax/task/__init__.py +0 -0
  29. {xax-0.3.3 → xax-0.3.4}/xax/task/base.py +0 -0
  30. {xax-0.3.3 → xax-0.3.4}/xax/task/launchers/__init__.py +0 -0
  31. {xax-0.3.3 → xax-0.3.4}/xax/task/launchers/base.py +0 -0
  32. {xax-0.3.3 → xax-0.3.4}/xax/task/launchers/cli.py +0 -0
  33. {xax-0.3.3 → xax-0.3.4}/xax/task/launchers/single_process.py +0 -0
  34. {xax-0.3.3 → xax-0.3.4}/xax/task/logger.py +0 -0
  35. {xax-0.3.3 → xax-0.3.4}/xax/task/loggers/__init__.py +0 -0
  36. {xax-0.3.3 → xax-0.3.4}/xax/task/loggers/callback.py +0 -0
  37. {xax-0.3.3 → xax-0.3.4}/xax/task/loggers/json.py +0 -0
  38. {xax-0.3.3 → xax-0.3.4}/xax/task/loggers/state.py +0 -0
  39. {xax-0.3.3 → xax-0.3.4}/xax/task/loggers/stdout.py +0 -0
  40. {xax-0.3.3 → xax-0.3.4}/xax/task/loggers/tensorboard.py +0 -0
  41. {xax-0.3.3 → xax-0.3.4}/xax/task/mixins/__init__.py +0 -0
  42. {xax-0.3.3 → xax-0.3.4}/xax/task/mixins/artifacts.py +0 -0
  43. {xax-0.3.3 → xax-0.3.4}/xax/task/mixins/checkpointing.py +0 -0
  44. {xax-0.3.3 → xax-0.3.4}/xax/task/mixins/compile.py +0 -0
  45. {xax-0.3.3 → xax-0.3.4}/xax/task/mixins/cpu_stats.py +0 -0
  46. {xax-0.3.3 → xax-0.3.4}/xax/task/mixins/data_loader.py +0 -0
  47. {xax-0.3.3 → xax-0.3.4}/xax/task/mixins/gpu_stats.py +0 -0
  48. {xax-0.3.3 → xax-0.3.4}/xax/task/mixins/logger.py +0 -0
  49. {xax-0.3.3 → xax-0.3.4}/xax/task/mixins/process.py +0 -0
  50. {xax-0.3.3 → xax-0.3.4}/xax/task/mixins/runnable.py +0 -0
  51. {xax-0.3.3 → xax-0.3.4}/xax/task/mixins/step_wrapper.py +0 -0
  52. {xax-0.3.3 → xax-0.3.4}/xax/task/mixins/train.py +0 -0
  53. {xax-0.3.3 → xax-0.3.4}/xax/task/script.py +0 -0
  54. {xax-0.3.3 → xax-0.3.4}/xax/task/task.py +0 -0
  55. {xax-0.3.3 → xax-0.3.4}/xax/utils/__init__.py +0 -0
  56. {xax-0.3.3 → xax-0.3.4}/xax/utils/data/__init__.py +0 -0
  57. {xax-0.3.3 → xax-0.3.4}/xax/utils/data/collate.py +0 -0
  58. {xax-0.3.3 → xax-0.3.4}/xax/utils/debugging.py +0 -0
  59. {xax-0.3.3 → xax-0.3.4}/xax/utils/experiments.py +0 -0
  60. {xax-0.3.3 → xax-0.3.4}/xax/utils/jax.py +0 -0
  61. {xax-0.3.3 → xax-0.3.4}/xax/utils/jaxpr.py +0 -0
  62. {xax-0.3.3 → xax-0.3.4}/xax/utils/logging.py +0 -0
  63. {xax-0.3.3 → xax-0.3.4}/xax/utils/numpy.py +0 -0
  64. {xax-0.3.3 → xax-0.3.4}/xax/utils/profile.py +0 -0
  65. {xax-0.3.3 → xax-0.3.4}/xax/utils/pytree.py +0 -0
  66. {xax-0.3.3 → xax-0.3.4}/xax/utils/tensorboard.py +0 -0
  67. {xax-0.3.3 → xax-0.3.4}/xax/utils/text.py +0 -0
  68. {xax-0.3.3 → xax-0.3.4}/xax/utils/types/__init__.py +0 -0
  69. {xax-0.3.3 → xax-0.3.4}/xax/utils/types/frozen_dict.py +0 -0
  70. {xax-0.3.3 → xax-0.3.4}/xax/utils/types/hashable_array.py +0 -0
  71. {xax-0.3.3 → xax-0.3.4}/xax.egg-info/SOURCES.txt +0 -0
  72. {xax-0.3.3 → xax-0.3.4}/xax.egg-info/dependency_links.txt +0 -0
  73. {xax-0.3.3 → xax-0.3.4}/xax.egg-info/entry_points.txt +0 -0
  74. {xax-0.3.3 → xax-0.3.4}/xax.egg-info/requires.txt +0 -0
  75. {xax-0.3.3 → xax-0.3.4}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.3
3
+ Version: 0.3.4
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.3"
15
+ __version__ = "0.3.4"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -23,10 +23,14 @@ __all__ = [
23
23
  "get_run_dir",
24
24
  "load_user_config",
25
25
  "State",
26
+ "AttentionCache",
27
+ "AttentionCacheDict",
26
28
  "CrossAttentionBlock",
27
29
  "SelfAttentionBlock",
28
30
  "Transformer",
29
31
  "TransformerBlock",
32
+ "TransformerCache",
33
+ "TransformerStack",
30
34
  "FourierEmbeddings",
31
35
  "IdentityPositionalEmbeddings",
32
36
  "LearnedPositionalEmbeddings",
@@ -206,10 +210,14 @@ NAME_MAP: dict[str, str] = {
206
210
  "get_run_dir": "core.conf",
207
211
  "load_user_config": "core.conf",
208
212
  "State": "core.state",
213
+ "AttentionCache": "nn.attention",
214
+ "AttentionCacheDict": "nn.attention",
209
215
  "CrossAttentionBlock": "nn.attention",
210
216
  "SelfAttentionBlock": "nn.attention",
211
217
  "Transformer": "nn.attention",
212
218
  "TransformerBlock": "nn.attention",
219
+ "TransformerCache": "nn.attention",
220
+ "TransformerStack": "nn.attention",
213
221
  "FourierEmbeddings": "nn.embeddings",
214
222
  "IdentityPositionalEmbeddings": "nn.embeddings",
215
223
  "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -362,6 +370,9 @@ NAME_MAP.update(
362
370
  },
363
371
  )
364
372
 
373
+ # In NAME_MAP
374
+ NAME_MAP["TransformerStack"] = "nn.attention"
375
+
365
376
 
366
377
  def __getattr__(name: str) -> object:
367
378
  if name not in NAME_MAP:
@@ -382,7 +393,16 @@ if IMPORT_ALL or TYPE_CHECKING:
382
393
  load_user_config,
383
394
  )
384
395
  from xax.core.state import Phase, State
385
- from xax.nn.attention import CrossAttentionBlock, SelfAttentionBlock, Transformer, TransformerBlock
396
+ from xax.nn.attention import (
397
+ AttentionCache,
398
+ AttentionCacheDict,
399
+ CrossAttentionBlock,
400
+ SelfAttentionBlock,
401
+ Transformer,
402
+ TransformerBlock,
403
+ TransformerCache,
404
+ TransformerStack,
405
+ )
386
406
  from xax.nn.embeddings import (
387
407
  EmbeddingKind,
388
408
  FourierEmbeddings,
@@ -411,12 +431,7 @@ if IMPORT_ALL or TYPE_CHECKING:
411
431
  rotation_matrix_to_rotation6d,
412
432
  )
413
433
  from xax.nn.losses import cross_entropy
414
- from xax.nn.metrics import (
415
- NormType,
416
- cast_norm_type,
417
- dynamic_time_warping,
418
- get_norm,
419
- )
434
+ from xax.nn.metrics import NormType, cast_norm_type, dynamic_time_warping, get_norm
420
435
  from xax.nn.parallel import is_master
421
436
  from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
422
437
  from xax.task.base import RawConfigType