xax 0.2.21__tar.gz → 0.2.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. {xax-0.2.21/xax.egg-info → xax-0.2.23}/PKG-INFO +1 -1
  2. {xax-0.2.21 → xax-0.2.23}/pyproject.toml +1 -0
  3. {xax-0.2.21 → xax-0.2.23}/xax/__init__.py +15 -2
  4. {xax-0.2.21 → xax-0.2.23}/xax/core/state.py +10 -37
  5. xax-0.2.23/xax/nn/attention.py +738 -0
  6. {xax-0.2.21 → xax-0.2.23}/xax/task/logger.py +1 -1
  7. {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/train.py +17 -19
  8. {xax-0.2.21 → xax-0.2.23}/xax/utils/experiments.py +2 -2
  9. {xax-0.2.21 → xax-0.2.23}/xax/utils/jax.py +109 -7
  10. {xax-0.2.21 → xax-0.2.23/xax.egg-info}/PKG-INFO +1 -1
  11. {xax-0.2.21 → xax-0.2.23}/xax.egg-info/SOURCES.txt +1 -0
  12. {xax-0.2.21 → xax-0.2.23}/LICENSE +0 -0
  13. {xax-0.2.21 → xax-0.2.23}/MANIFEST.in +0 -0
  14. {xax-0.2.21 → xax-0.2.23}/README.md +0 -0
  15. {xax-0.2.21 → xax-0.2.23}/setup.cfg +0 -0
  16. {xax-0.2.21 → xax-0.2.23}/setup.py +0 -0
  17. {xax-0.2.21 → xax-0.2.23}/xax/cli/__init__.py +0 -0
  18. {xax-0.2.21 → xax-0.2.23}/xax/cli/edit_config.py +0 -0
  19. {xax-0.2.21 → xax-0.2.23}/xax/core/__init__.py +0 -0
  20. {xax-0.2.21 → xax-0.2.23}/xax/core/conf.py +0 -0
  21. {xax-0.2.21 → xax-0.2.23}/xax/nn/__init__.py +0 -0
  22. {xax-0.2.21 → xax-0.2.23}/xax/nn/embeddings.py +0 -0
  23. {xax-0.2.21 → xax-0.2.23}/xax/nn/functions.py +0 -0
  24. {xax-0.2.21 → xax-0.2.23}/xax/nn/geom.py +0 -0
  25. {xax-0.2.21 → xax-0.2.23}/xax/nn/losses.py +0 -0
  26. {xax-0.2.21 → xax-0.2.23}/xax/nn/metrics.py +0 -0
  27. {xax-0.2.21 → xax-0.2.23}/xax/nn/parallel.py +0 -0
  28. {xax-0.2.21 → xax-0.2.23}/xax/nn/ssm.py +0 -0
  29. {xax-0.2.21 → xax-0.2.23}/xax/py.typed +0 -0
  30. {xax-0.2.21 → xax-0.2.23}/xax/requirements-dev.txt +0 -0
  31. {xax-0.2.21 → xax-0.2.23}/xax/requirements.txt +0 -0
  32. {xax-0.2.21 → xax-0.2.23}/xax/task/__init__.py +0 -0
  33. {xax-0.2.21 → xax-0.2.23}/xax/task/base.py +0 -0
  34. {xax-0.2.21 → xax-0.2.23}/xax/task/launchers/__init__.py +0 -0
  35. {xax-0.2.21 → xax-0.2.23}/xax/task/launchers/base.py +0 -0
  36. {xax-0.2.21 → xax-0.2.23}/xax/task/launchers/cli.py +0 -0
  37. {xax-0.2.21 → xax-0.2.23}/xax/task/launchers/single_process.py +0 -0
  38. {xax-0.2.21 → xax-0.2.23}/xax/task/loggers/__init__.py +0 -0
  39. {xax-0.2.21 → xax-0.2.23}/xax/task/loggers/callback.py +0 -0
  40. {xax-0.2.21 → xax-0.2.23}/xax/task/loggers/json.py +0 -0
  41. {xax-0.2.21 → xax-0.2.23}/xax/task/loggers/state.py +0 -0
  42. {xax-0.2.21 → xax-0.2.23}/xax/task/loggers/stdout.py +0 -0
  43. {xax-0.2.21 → xax-0.2.23}/xax/task/loggers/tensorboard.py +0 -0
  44. {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/__init__.py +0 -0
  45. {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/artifacts.py +0 -0
  46. {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/checkpointing.py +0 -0
  47. {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/compile.py +0 -0
  48. {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/cpu_stats.py +0 -0
  49. {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/data_loader.py +0 -0
  50. {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/gpu_stats.py +0 -0
  51. {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/logger.py +0 -0
  52. {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/process.py +0 -0
  53. {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/runnable.py +0 -0
  54. {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/step_wrapper.py +0 -0
  55. {xax-0.2.21 → xax-0.2.23}/xax/task/script.py +0 -0
  56. {xax-0.2.21 → xax-0.2.23}/xax/task/task.py +0 -0
  57. {xax-0.2.21 → xax-0.2.23}/xax/utils/__init__.py +0 -0
  58. {xax-0.2.21 → xax-0.2.23}/xax/utils/data/__init__.py +0 -0
  59. {xax-0.2.21 → xax-0.2.23}/xax/utils/data/collate.py +0 -0
  60. {xax-0.2.21 → xax-0.2.23}/xax/utils/debugging.py +0 -0
  61. {xax-0.2.21 → xax-0.2.23}/xax/utils/jaxpr.py +0 -0
  62. {xax-0.2.21 → xax-0.2.23}/xax/utils/logging.py +0 -0
  63. {xax-0.2.21 → xax-0.2.23}/xax/utils/numpy.py +0 -0
  64. {xax-0.2.21 → xax-0.2.23}/xax/utils/profile.py +0 -0
  65. {xax-0.2.21 → xax-0.2.23}/xax/utils/pytree.py +0 -0
  66. {xax-0.2.21 → xax-0.2.23}/xax/utils/tensorboard.py +0 -0
  67. {xax-0.2.21 → xax-0.2.23}/xax/utils/text.py +0 -0
  68. {xax-0.2.21 → xax-0.2.23}/xax/utils/types/__init__.py +0 -0
  69. {xax-0.2.21 → xax-0.2.23}/xax/utils/types/frozen_dict.py +0 -0
  70. {xax-0.2.21 → xax-0.2.23}/xax/utils/types/hashable_array.py +0 -0
  71. {xax-0.2.21 → xax-0.2.23}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.2.21 → xax-0.2.23}/xax.egg-info/entry_points.txt +0 -0
  73. {xax-0.2.21 → xax-0.2.23}/xax.egg-info/requires.txt +0 -0
  74. {xax-0.2.21 → xax-0.2.23}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.21
3
+ Version: 0.2.23
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -38,6 +38,7 @@ module = [
38
38
  "optax.*",
39
39
  "setuptools.*",
40
40
  "tensorboard.*",
41
+ "tensorflow_datasets.*",
41
42
  "transformers.*",
42
43
  ]
43
44
 
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.21"
15
+ __version__ = "0.2.23"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -23,6 +23,10 @@ __all__ = [
23
23
  "get_run_dir",
24
24
  "load_user_config",
25
25
  "State",
26
+ "CrossAttentionBlock",
27
+ "SelfAttentionBlock",
28
+ "Transformer",
29
+ "TransformerBlock",
26
30
  "FourierEmbeddings",
27
31
  "IdentityPositionalEmbeddings",
28
32
  "LearnedPositionalEmbeddings",
@@ -112,8 +116,10 @@ __all__ = [
112
116
  "save_config",
113
117
  "stage_environment",
114
118
  "to_markdown_table",
119
+ "grad",
115
120
  "jit",
116
121
  "scan",
122
+ "vmap",
117
123
  "save_jaxpr_dot",
118
124
  "ColoredFormatter",
119
125
  "configure_logging",
@@ -198,6 +204,10 @@ NAME_MAP: dict[str, str] = {
198
204
  "get_run_dir": "core.conf",
199
205
  "load_user_config": "core.conf",
200
206
  "State": "core.state",
207
+ "CrossAttentionBlock": "nn.attention",
208
+ "SelfAttentionBlock": "nn.attention",
209
+ "Transformer": "nn.attention",
210
+ "TransformerBlock": "nn.attention",
201
211
  "FourierEmbeddings": "nn.embeddings",
202
212
  "IdentityPositionalEmbeddings": "nn.embeddings",
203
213
  "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -287,8 +297,10 @@ NAME_MAP: dict[str, str] = {
287
297
  "save_config": "utils.experiments",
288
298
  "stage_environment": "utils.experiments",
289
299
  "to_markdown_table": "utils.experiments",
300
+ "grad": "utils.jax",
290
301
  "jit": "utils.jax",
291
302
  "scan": "utils.jax",
303
+ "vmap": "utils.jax",
292
304
  "save_jaxpr_dot": "utils.jaxpr",
293
305
  "ColoredFormatter": "utils.logging",
294
306
  "configure_logging": "utils.logging",
@@ -366,6 +378,7 @@ if IMPORT_ALL or TYPE_CHECKING:
366
378
  load_user_config,
367
379
  )
368
380
  from xax.core.state import Phase, State
381
+ from xax.nn.attention import CrossAttentionBlock, SelfAttentionBlock, Transformer, TransformerBlock
369
382
  from xax.nn.embeddings import (
370
383
  EmbeddingKind,
371
384
  FourierEmbeddings,
@@ -460,7 +473,7 @@ if IMPORT_ALL or TYPE_CHECKING:
460
473
  stage_environment,
461
474
  to_markdown_table,
462
475
  )
463
- from xax.utils.jax import jit, scan
476
+ from xax.utils.jax import grad, jit, scan, vmap
464
477
  from xax.utils.jaxpr import save_jaxpr_dot
465
478
  from xax.utils.logging import (
466
479
  LOG_ERROR_SUMMARY,
@@ -27,11 +27,8 @@ def _int_to_phase(i: int) -> Phase:
27
27
  class StateDict(TypedDict, total=False):
28
28
  num_steps: NotRequired[int | Array]
29
29
  num_samples: NotRequired[int | Array]
30
- num_valid_steps: NotRequired[int | Array]
31
- num_valid_samples: NotRequired[int | Array]
32
30
  start_time_s: NotRequired[float | Array]
33
31
  elapsed_time_s: NotRequired[float | Array]
34
- valid_elapsed_time_s: NotRequired[float | Array]
35
32
  phase: NotRequired[Phase]
36
33
  _phase: NotRequired[int | Array]
37
34
 
@@ -47,38 +44,26 @@ class State:
47
44
  return self._int32_arr[0]
48
45
 
49
46
  @property
50
- def num_valid_steps(self) -> Array:
51
- return self._int32_arr[1]
47
+ def phase(self) -> Phase:
48
+ return _int_to_phase(self._int32_arr[1].item())
52
49
 
53
50
  @property
54
51
  def num_samples(self) -> Array:
55
52
  return self._float32_arr[0]
56
53
 
57
- @property
58
- def num_valid_samples(self) -> Array:
59
- return self._float32_arr[1]
60
-
61
54
  @property
62
55
  def start_time_s(self) -> Array:
63
- return self._float32_arr[2]
56
+ return self._float32_arr[1]
64
57
 
65
58
  @property
66
59
  def elapsed_time_s(self) -> Array:
67
- return self._float32_arr[3]
68
-
69
- @property
70
- def valid_elapsed_time_s(self) -> Array:
71
- return self._float32_arr[4]
72
-
73
- @property
74
- def phase(self) -> Phase:
75
- return _int_to_phase(self._int32_arr[2].item())
60
+ return self._float32_arr[2]
76
61
 
77
62
  @classmethod
78
63
  def init_state(cls) -> "State":
79
64
  return cls(
80
- _int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
81
- _float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0, 0.0], dtype=jnp.float32),
65
+ _int32_arr=jnp.array([0, 0], dtype=jnp.int32),
66
+ _float32_arr=jnp.array([0.0, time.time(), 0.0], dtype=jnp.float32),
82
67
  )
83
68
 
84
69
  @property
@@ -91,25 +76,19 @@ class State:
91
76
 
92
77
  if "num_steps" in kwargs:
93
78
  int32_arr = int32_arr.at[0].set(kwargs["num_steps"])
94
- if "num_valid_steps" in kwargs:
95
- int32_arr = int32_arr.at[1].set(kwargs["num_valid_steps"])
96
79
 
97
80
  if "phase" in kwargs:
98
- int32_arr = int32_arr.at[2].set(_phase_to_int(kwargs["phase"]))
81
+ int32_arr = int32_arr.at[1].set(_phase_to_int(kwargs["phase"]))
99
82
  if "_phase" in kwargs:
100
- int32_arr = int32_arr.at[2].set(kwargs["_phase"])
83
+ int32_arr = int32_arr.at[1].set(kwargs["_phase"])
101
84
 
102
85
  if "num_samples" in kwargs:
103
86
  float32_arr = float32_arr.at[0].set(kwargs["num_samples"])
104
- if "num_valid_samples" in kwargs:
105
- float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
106
87
 
107
88
  if "start_time_s" in kwargs:
108
- float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
89
+ float32_arr = float32_arr.at[1].set(kwargs["start_time_s"])
109
90
  if "elapsed_time_s" in kwargs:
110
- float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
111
- if "valid_elapsed_time_s" in kwargs:
112
- float32_arr = float32_arr.at[4].set(kwargs["valid_elapsed_time_s"])
91
+ float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])
113
92
 
114
93
  return State(
115
94
  _int32_arr=int32_arr,
@@ -119,12 +98,9 @@ class State:
119
98
  def to_dict(self) -> dict[str, int | float | str]:
120
99
  return {
121
100
  "num_steps": int(self.num_steps.item()),
122
- "num_valid_steps": int(self.num_valid_steps.item()),
123
101
  "num_samples": int(self.num_samples.item()),
124
- "num_valid_samples": int(self.num_valid_samples.item()),
125
102
  "start_time_s": float(self.start_time_s.item()),
126
103
  "elapsed_time_s": float(self.elapsed_time_s.item()),
127
- "valid_elapsed_time_s": float(self.valid_elapsed_time_s.item()),
128
104
  "phase": str(self.phase),
129
105
  }
130
106
 
@@ -136,7 +112,6 @@ class State:
136
112
  int32_arr = jnp.array(
137
113
  [
138
114
  d.get("num_steps", 0),
139
- d.get("num_valid_steps", 0),
140
115
  d.get("_phase", 0),
141
116
  ],
142
117
  dtype=jnp.int32,
@@ -145,10 +120,8 @@ class State:
145
120
  float32_arr = jnp.array(
146
121
  [
147
122
  d.get("num_samples", 0),
148
- d.get("num_valid_samples", 0),
149
123
  d.get("start_time_s", time.time()),
150
124
  d.get("elapsed_time_s", 0.0),
151
- d.get("valid_elapsed_time_s", 0.0),
152
125
  ],
153
126
  dtype=jnp.float32,
154
127
  )