xax 0.3.3__tar.gz → 0.3.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {xax-0.3.3/xax.egg-info → xax-0.3.9}/PKG-INFO +1 -1
  2. {xax-0.3.3 → xax-0.3.9}/xax/__init__.py +35 -8
  3. xax-0.3.9/xax/nn/attention.py +940 -0
  4. xax-0.3.9/xax/nn/distributions.py +181 -0
  5. {xax-0.3.3 → xax-0.3.9}/xax/nn/embeddings.py +10 -10
  6. {xax-0.3.3 → xax-0.3.9}/xax/nn/geom.py +5 -5
  7. {xax-0.3.3 → xax-0.3.9}/xax/nn/ssm.py +6 -6
  8. {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/data_loader.py +7 -2
  9. {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/train.py +51 -58
  10. {xax-0.3.3 → xax-0.3.9}/xax/utils/pytree.py +13 -0
  11. {xax-0.3.3 → xax-0.3.9/xax.egg-info}/PKG-INFO +1 -1
  12. {xax-0.3.3 → xax-0.3.9}/xax.egg-info/SOURCES.txt +1 -0
  13. xax-0.3.3/xax/nn/attention.py +0 -738
  14. {xax-0.3.3 → xax-0.3.9}/LICENSE +0 -0
  15. {xax-0.3.3 → xax-0.3.9}/MANIFEST.in +0 -0
  16. {xax-0.3.3 → xax-0.3.9}/README.md +0 -0
  17. {xax-0.3.3 → xax-0.3.9}/pyproject.toml +0 -0
  18. {xax-0.3.3 → xax-0.3.9}/setup.cfg +0 -0
  19. {xax-0.3.3 → xax-0.3.9}/setup.py +0 -0
  20. {xax-0.3.3 → xax-0.3.9}/xax/cli/__init__.py +0 -0
  21. {xax-0.3.3 → xax-0.3.9}/xax/cli/edit_config.py +0 -0
  22. {xax-0.3.3 → xax-0.3.9}/xax/core/__init__.py +0 -0
  23. {xax-0.3.3 → xax-0.3.9}/xax/core/conf.py +0 -0
  24. {xax-0.3.3 → xax-0.3.9}/xax/core/state.py +0 -0
  25. {xax-0.3.3 → xax-0.3.9}/xax/nn/__init__.py +0 -0
  26. {xax-0.3.3 → xax-0.3.9}/xax/nn/functions.py +0 -0
  27. {xax-0.3.3 → xax-0.3.9}/xax/nn/losses.py +0 -0
  28. {xax-0.3.3 → xax-0.3.9}/xax/nn/metrics.py +0 -0
  29. {xax-0.3.3 → xax-0.3.9}/xax/nn/parallel.py +0 -0
  30. {xax-0.3.3 → xax-0.3.9}/xax/py.typed +0 -0
  31. {xax-0.3.3 → xax-0.3.9}/xax/requirements-dev.txt +0 -0
  32. {xax-0.3.3 → xax-0.3.9}/xax/requirements.txt +0 -0
  33. {xax-0.3.3 → xax-0.3.9}/xax/task/__init__.py +0 -0
  34. {xax-0.3.3 → xax-0.3.9}/xax/task/base.py +0 -0
  35. {xax-0.3.3 → xax-0.3.9}/xax/task/launchers/__init__.py +0 -0
  36. {xax-0.3.3 → xax-0.3.9}/xax/task/launchers/base.py +0 -0
  37. {xax-0.3.3 → xax-0.3.9}/xax/task/launchers/cli.py +0 -0
  38. {xax-0.3.3 → xax-0.3.9}/xax/task/launchers/single_process.py +0 -0
  39. {xax-0.3.3 → xax-0.3.9}/xax/task/logger.py +0 -0
  40. {xax-0.3.3 → xax-0.3.9}/xax/task/loggers/__init__.py +0 -0
  41. {xax-0.3.3 → xax-0.3.9}/xax/task/loggers/callback.py +0 -0
  42. {xax-0.3.3 → xax-0.3.9}/xax/task/loggers/json.py +0 -0
  43. {xax-0.3.3 → xax-0.3.9}/xax/task/loggers/state.py +0 -0
  44. {xax-0.3.3 → xax-0.3.9}/xax/task/loggers/stdout.py +0 -0
  45. {xax-0.3.3 → xax-0.3.9}/xax/task/loggers/tensorboard.py +0 -0
  46. {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/__init__.py +0 -0
  47. {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/artifacts.py +0 -0
  48. {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/checkpointing.py +0 -0
  49. {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/compile.py +0 -0
  50. {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/cpu_stats.py +0 -0
  51. {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/gpu_stats.py +0 -0
  52. {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/logger.py +0 -0
  53. {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/process.py +0 -0
  54. {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/runnable.py +0 -0
  55. {xax-0.3.3 → xax-0.3.9}/xax/task/mixins/step_wrapper.py +0 -0
  56. {xax-0.3.3 → xax-0.3.9}/xax/task/script.py +0 -0
  57. {xax-0.3.3 → xax-0.3.9}/xax/task/task.py +0 -0
  58. {xax-0.3.3 → xax-0.3.9}/xax/utils/__init__.py +0 -0
  59. {xax-0.3.3 → xax-0.3.9}/xax/utils/data/__init__.py +0 -0
  60. {xax-0.3.3 → xax-0.3.9}/xax/utils/data/collate.py +0 -0
  61. {xax-0.3.3 → xax-0.3.9}/xax/utils/debugging.py +0 -0
  62. {xax-0.3.3 → xax-0.3.9}/xax/utils/experiments.py +0 -0
  63. {xax-0.3.3 → xax-0.3.9}/xax/utils/jax.py +0 -0
  64. {xax-0.3.3 → xax-0.3.9}/xax/utils/jaxpr.py +0 -0
  65. {xax-0.3.3 → xax-0.3.9}/xax/utils/logging.py +0 -0
  66. {xax-0.3.3 → xax-0.3.9}/xax/utils/numpy.py +0 -0
  67. {xax-0.3.3 → xax-0.3.9}/xax/utils/profile.py +0 -0
  68. {xax-0.3.3 → xax-0.3.9}/xax/utils/tensorboard.py +0 -0
  69. {xax-0.3.3 → xax-0.3.9}/xax/utils/text.py +0 -0
  70. {xax-0.3.3 → xax-0.3.9}/xax/utils/types/__init__.py +0 -0
  71. {xax-0.3.3 → xax-0.3.9}/xax/utils/types/frozen_dict.py +0 -0
  72. {xax-0.3.3 → xax-0.3.9}/xax/utils/types/hashable_array.py +0 -0
  73. {xax-0.3.3 → xax-0.3.9}/xax.egg-info/dependency_links.txt +0 -0
  74. {xax-0.3.3 → xax-0.3.9}/xax.egg-info/entry_points.txt +0 -0
  75. {xax-0.3.3 → xax-0.3.9}/xax.egg-info/requires.txt +0 -0
  76. {xax-0.3.3 → xax-0.3.9}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.3
3
+ Version: 0.3.9
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.3"
15
+ __version__ = "0.3.9"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -23,10 +23,18 @@ __all__ = [
23
23
  "get_run_dir",
24
24
  "load_user_config",
25
25
  "State",
26
+ "AttentionCache",
27
+ "AttentionCacheDict",
26
28
  "CrossAttentionBlock",
27
29
  "SelfAttentionBlock",
28
30
  "Transformer",
29
31
  "TransformerBlock",
32
+ "TransformerCache",
33
+ "TransformerStack",
34
+ "Categorical",
35
+ "Distribution",
36
+ "MixtureOfGaussians",
37
+ "Normal",
30
38
  "FourierEmbeddings",
31
39
  "IdentityPositionalEmbeddings",
32
40
  "LearnedPositionalEmbeddings",
@@ -132,6 +140,7 @@ __all__ = [
132
140
  "compute_nan_ratio",
133
141
  "flatten_array",
134
142
  "flatten_pytree",
143
+ "get_pytree_mapping",
135
144
  "get_pytree_param_count",
136
145
  "pytree_has_nans",
137
146
  "reshuffle_pytree",
@@ -206,10 +215,18 @@ NAME_MAP: dict[str, str] = {
206
215
  "get_run_dir": "core.conf",
207
216
  "load_user_config": "core.conf",
208
217
  "State": "core.state",
218
+ "AttentionCache": "nn.attention",
219
+ "AttentionCacheDict": "nn.attention",
209
220
  "CrossAttentionBlock": "nn.attention",
210
221
  "SelfAttentionBlock": "nn.attention",
211
222
  "Transformer": "nn.attention",
212
223
  "TransformerBlock": "nn.attention",
224
+ "TransformerCache": "nn.attention",
225
+ "TransformerStack": "nn.attention",
226
+ "Categorical": "nn.distributions",
227
+ "Distribution": "nn.distributions",
228
+ "MixtureOfGaussians": "nn.distributions",
229
+ "Normal": "nn.distributions",
213
230
  "FourierEmbeddings": "nn.embeddings",
214
231
  "IdentityPositionalEmbeddings": "nn.embeddings",
215
232
  "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -315,6 +332,7 @@ NAME_MAP: dict[str, str] = {
315
332
  "compute_nan_ratio": "utils.pytree",
316
333
  "flatten_array": "utils.pytree",
317
334
  "flatten_pytree": "utils.pytree",
335
+ "get_pytree_mapping": "utils.pytree",
318
336
  "get_pytree_param_count": "utils.pytree",
319
337
  "pytree_has_nans": "utils.pytree",
320
338
  "reshuffle_pytree": "utils.pytree",
@@ -362,6 +380,9 @@ NAME_MAP.update(
362
380
  },
363
381
  )
364
382
 
383
+ # In NAME_MAP
384
+ NAME_MAP["TransformerStack"] = "nn.attention"
385
+
365
386
 
366
387
  def __getattr__(name: str) -> object:
367
388
  if name not in NAME_MAP:
@@ -382,7 +403,17 @@ if IMPORT_ALL or TYPE_CHECKING:
382
403
  load_user_config,
383
404
  )
384
405
  from xax.core.state import Phase, State
385
- from xax.nn.attention import CrossAttentionBlock, SelfAttentionBlock, Transformer, TransformerBlock
406
+ from xax.nn.attention import (
407
+ AttentionCache,
408
+ AttentionCacheDict,
409
+ CrossAttentionBlock,
410
+ SelfAttentionBlock,
411
+ Transformer,
412
+ TransformerBlock,
413
+ TransformerCache,
414
+ TransformerStack,
415
+ )
416
+ from xax.nn.distributions import Categorical, Distribution, MixtureOfGaussians, Normal
386
417
  from xax.nn.embeddings import (
387
418
  EmbeddingKind,
388
419
  FourierEmbeddings,
@@ -411,12 +442,7 @@ if IMPORT_ALL or TYPE_CHECKING:
411
442
  rotation_matrix_to_rotation6d,
412
443
  )
413
444
  from xax.nn.losses import cross_entropy
414
- from xax.nn.metrics import (
415
- NormType,
416
- cast_norm_type,
417
- dynamic_time_warping,
418
- get_norm,
419
- )
445
+ from xax.nn.metrics import NormType, cast_norm_type, dynamic_time_warping, get_norm
420
446
  from xax.nn.parallel import is_master
421
447
  from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
422
448
  from xax.task.base import RawConfigType
@@ -494,6 +520,7 @@ if IMPORT_ALL or TYPE_CHECKING:
494
520
  compute_nan_ratio,
495
521
  flatten_array,
496
522
  flatten_pytree,
523
+ get_pytree_mapping,
497
524
  get_pytree_param_count,
498
525
  pytree_has_nans,
499
526
  reshuffle_pytree,