xax 0.1.11__tar.gz → 0.1.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {xax-0.1.11/xax.egg-info → xax-0.1.12}/PKG-INFO +1 -1
  2. {xax-0.1.11 → xax-0.1.12}/xax/__init__.py +9 -10
  3. {xax-0.1.11 → xax-0.1.12}/xax/nn/geom.py +57 -0
  4. xax-0.1.12/xax/nn/ssm.py +316 -0
  5. {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/train.py +15 -7
  6. {xax-0.1.11 → xax-0.1.12/xax.egg-info}/PKG-INFO +1 -1
  7. xax-0.1.11/xax/nn/ssm.py +0 -296
  8. {xax-0.1.11 → xax-0.1.12}/LICENSE +0 -0
  9. {xax-0.1.11 → xax-0.1.12}/MANIFEST.in +0 -0
  10. {xax-0.1.11 → xax-0.1.12}/README.md +0 -0
  11. {xax-0.1.11 → xax-0.1.12}/pyproject.toml +0 -0
  12. {xax-0.1.11 → xax-0.1.12}/setup.cfg +0 -0
  13. {xax-0.1.11 → xax-0.1.12}/setup.py +0 -0
  14. {xax-0.1.11 → xax-0.1.12}/xax/core/__init__.py +0 -0
  15. {xax-0.1.11 → xax-0.1.12}/xax/core/conf.py +0 -0
  16. {xax-0.1.11 → xax-0.1.12}/xax/core/state.py +0 -0
  17. {xax-0.1.11 → xax-0.1.12}/xax/nn/__init__.py +0 -0
  18. {xax-0.1.11 → xax-0.1.12}/xax/nn/embeddings.py +0 -0
  19. {xax-0.1.11 → xax-0.1.12}/xax/nn/equinox.py +0 -0
  20. {xax-0.1.11 → xax-0.1.12}/xax/nn/export.py +0 -0
  21. {xax-0.1.11 → xax-0.1.12}/xax/nn/functions.py +0 -0
  22. {xax-0.1.11 → xax-0.1.12}/xax/nn/losses.py +0 -0
  23. {xax-0.1.11 → xax-0.1.12}/xax/nn/norm.py +0 -0
  24. {xax-0.1.11 → xax-0.1.12}/xax/nn/parallel.py +0 -0
  25. {xax-0.1.11 → xax-0.1.12}/xax/py.typed +0 -0
  26. {xax-0.1.11 → xax-0.1.12}/xax/requirements-dev.txt +0 -0
  27. {xax-0.1.11 → xax-0.1.12}/xax/requirements.txt +0 -0
  28. {xax-0.1.11 → xax-0.1.12}/xax/task/__init__.py +0 -0
  29. {xax-0.1.11 → xax-0.1.12}/xax/task/base.py +0 -0
  30. {xax-0.1.11 → xax-0.1.12}/xax/task/launchers/__init__.py +0 -0
  31. {xax-0.1.11 → xax-0.1.12}/xax/task/launchers/base.py +0 -0
  32. {xax-0.1.11 → xax-0.1.12}/xax/task/launchers/cli.py +0 -0
  33. {xax-0.1.11 → xax-0.1.12}/xax/task/launchers/single_process.py +0 -0
  34. {xax-0.1.11 → xax-0.1.12}/xax/task/logger.py +0 -0
  35. {xax-0.1.11 → xax-0.1.12}/xax/task/loggers/__init__.py +0 -0
  36. {xax-0.1.11 → xax-0.1.12}/xax/task/loggers/callback.py +0 -0
  37. {xax-0.1.11 → xax-0.1.12}/xax/task/loggers/json.py +0 -0
  38. {xax-0.1.11 → xax-0.1.12}/xax/task/loggers/state.py +0 -0
  39. {xax-0.1.11 → xax-0.1.12}/xax/task/loggers/stdout.py +0 -0
  40. {xax-0.1.11 → xax-0.1.12}/xax/task/loggers/tensorboard.py +0 -0
  41. {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/__init__.py +0 -0
  42. {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/artifacts.py +0 -0
  43. {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/checkpointing.py +0 -0
  44. {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/compile.py +0 -0
  45. {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/cpu_stats.py +0 -0
  46. {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/data_loader.py +0 -0
  47. {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/gpu_stats.py +0 -0
  48. {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/logger.py +0 -0
  49. {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/process.py +0 -0
  50. {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/runnable.py +0 -0
  51. {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/step_wrapper.py +0 -0
  52. {xax-0.1.11 → xax-0.1.12}/xax/task/script.py +0 -0
  53. {xax-0.1.11 → xax-0.1.12}/xax/task/task.py +0 -0
  54. {xax-0.1.11 → xax-0.1.12}/xax/utils/__init__.py +0 -0
  55. {xax-0.1.11 → xax-0.1.12}/xax/utils/data/__init__.py +0 -0
  56. {xax-0.1.11 → xax-0.1.12}/xax/utils/data/collate.py +0 -0
  57. {xax-0.1.11 → xax-0.1.12}/xax/utils/debugging.py +0 -0
  58. {xax-0.1.11 → xax-0.1.12}/xax/utils/experiments.py +0 -0
  59. {xax-0.1.11 → xax-0.1.12}/xax/utils/jax.py +0 -0
  60. {xax-0.1.11 → xax-0.1.12}/xax/utils/jaxpr.py +0 -0
  61. {xax-0.1.11 → xax-0.1.12}/xax/utils/logging.py +0 -0
  62. {xax-0.1.11 → xax-0.1.12}/xax/utils/numpy.py +0 -0
  63. {xax-0.1.11 → xax-0.1.12}/xax/utils/profile.py +0 -0
  64. {xax-0.1.11 → xax-0.1.12}/xax/utils/pytree.py +0 -0
  65. {xax-0.1.11 → xax-0.1.12}/xax/utils/tensorboard.py +0 -0
  66. {xax-0.1.11 → xax-0.1.12}/xax/utils/text.py +0 -0
  67. {xax-0.1.11 → xax-0.1.12}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.1.11 → xax-0.1.12}/xax/utils/types/frozen_dict.py +0 -0
  69. {xax-0.1.11 → xax-0.1.12}/xax/utils/types/hashable_array.py +0 -0
  70. {xax-0.1.11 → xax-0.1.12}/xax.egg-info/SOURCES.txt +0 -0
  71. {xax-0.1.11 → xax-0.1.12}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.1.11 → xax-0.1.12}/xax.egg-info/requires.txt +0 -0
  73. {xax-0.1.11 → xax-0.1.12}/xax.egg-info/top_level.txt +0 -0
{xax-0.1.11/xax.egg-info → xax-0.1.12}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.1.11
+Version: 0.1.12
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
{xax-0.1.11 → xax-0.1.12}/xax/__init__.py
@@ -12,7 +12,7 @@ and running the update script:
 python -m scripts.update_api --inplace
 """
 
-__version__ = "0.1.11"
+__version__ = "0.1.12"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -43,15 +43,14 @@ __all__ = [
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "rotate_vector_by_quat",
     "cross_entropy",
     "cast_norm_type",
     "get_norm",
     "is_master",
+    "BaseSSMBlock",
     "DiagSSMBlock",
-    "DiscreteTimeS4",
-    "S4",
-    "S4Layer",
-    "S6Layer",
+    "SSM",
     "SSMBlock",
     "BaseLauncher",
     "CliLauncher",
@@ -203,15 +202,14 @@ NAME_MAP: dict[str, str] = {
     "euler_to_quat": "nn.geom",
     "get_projected_gravity_vector_from_quat": "nn.geom",
     "quat_to_euler": "nn.geom",
+    "rotate_vector_by_quat": "nn.geom",
     "cross_entropy": "nn.losses",
     "cast_norm_type": "nn.norm",
     "get_norm": "nn.norm",
     "is_master": "nn.parallel",
+    "BaseSSMBlock": "nn.ssm",
     "DiagSSMBlock": "nn.ssm",
-    "DiscreteTimeS4": "nn.ssm",
-    "S4": "nn.ssm",
-    "S4Layer": "nn.ssm",
-    "S6Layer": "nn.ssm",
+    "SSM": "nn.ssm",
     "SSMBlock": "nn.ssm",
     "BaseLauncher": "task.launchers.base",
     "CliLauncher": "task.launchers.cli",
@@ -364,11 +362,12 @@ if IMPORT_ALL or TYPE_CHECKING:
         euler_to_quat,
         get_projected_gravity_vector_from_quat,
         quat_to_euler,
+        rotate_vector_by_quat,
     )
     from xax.nn.losses import cross_entropy
     from xax.nn.norm import NormType, cast_norm_type, get_norm
     from xax.nn.parallel import is_master
-    from xax.nn.ssm import S4, DiagSSMBlock, DiscreteTimeS4, S4Layer, S6Layer, SSMBlock
+    from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
     from xax.task.base import RawConfigType
     from xax.task.launchers.base import BaseLauncher
     from xax.task.launchers.cli import CliLauncher
{xax-0.1.11 → xax-0.1.12}/xax/nn/geom.py
@@ -99,3 +99,60 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -
 
     # Note: We're rotating [0,0,-1], so we negate gz to match the expected direction
     return jnp.concatenate([gx, gy, -gz], axis=-1)
+
+
+def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) -> jax.Array:
+    """Rotates a vector by a quaternion.
+
+    Args:
+        vector: The vector to rotate, shape (*, 3).
+        quat: The quaternion to rotate by, shape (*, 4).
+        eps: A small epsilon value to avoid division by zero.
+
+    Returns:
+        The rotated vector, shape (*, 3).
+    """
+    # Normalize quaternion
+    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
+    w, x, y, z = jnp.split(quat, 4, axis=-1)
+
+    # Extract vector components
+    vx, vy, vz = jnp.split(vector, 3, axis=-1)
+
+    # Terms for x component
+    xx = (
+        w * w * vx
+        + 2 * y * w * vz
+        - 2 * z * w * vy
+        + x * x * vx
+        + 2 * y * x * vy
+        + 2 * z * x * vz
+        - z * z * vx
+        - y * y * vx
+    )
+
+    # Terms for y component
+    yy = (
+        2 * x * y * vx
+        + y * y * vy
+        + 2 * z * y * vz
+        + 2 * w * z * vx
+        - z * z * vy
+        + w * w * vy
+        - 2 * w * x * vz
+        - x * x * vy
+    )
+
+    # Terms for z component
+    zz = (
+        2 * x * z * vx
+        + 2 * y * z * vy
+        + z * z * vz
+        - 2 * w * y * vx
+        + w * w * vz
+        + 2 * w * x * vy
+        - y * y * vz
+        - x * x * vz
+    )
+
+    return jnp.concatenate([xx, yy, zz], axis=-1)
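
The expanded sums above are the components of the Hamilton product q v q^-1, with the quaternion stored in (w, x, y, z) order. As a quick sanity check (a usage sketch; the expected value follows from quaternion algebra, not from the package's tests), rotating the unit x-axis by 90 degrees about z should give the unit y-axis:

import jax.numpy as jnp
from xax.nn.geom import rotate_vector_by_quat

# Quaternion for a 90-degree rotation about the z-axis, in (w, x, y, z) order.
theta = jnp.pi / 2
quat = jnp.array([jnp.cos(theta / 2), 0.0, 0.0, jnp.sin(theta / 2)])
vec = jnp.array([1.0, 0.0, 0.0])

print(rotate_vector_by_quat(vec, quat))  # approximately [0., 1., 0.]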
xax-0.1.12/xax/nn/ssm.py ADDED
@@ -0,0 +1,316 @@
+"""State space models."""
+
+from abc import ABC, abstractmethod
+from typing import Literal
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+
+def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
+    return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))
+
+
+class BaseSSMBlock(eqx.Module, ABC):
+    @abstractmethod
+    def forward(self, h: Array, x: Array) -> Array: ...
+
+    @abstractmethod
+    def forward_sequence(self, x_seq: Array) -> Array: ...
+
+    @abstractmethod
+    def get_a_mat(self, x: Array) -> Array: ...
+
+    @abstractmethod
+    def get_b_mat(self, x: Array) -> Array: ...
+
+
+class SSMBlock(BaseSSMBlock):
+    a_mat: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        key_a, key_b = jax.random.split(key)
+        self.a_mat = glorot(key_a, (hidden_size, hidden_size))
+        self.b_mat = glorot(key_b, (hidden_size, hidden_size))
+
+    def get_a_mat(self, x: Array) -> Array:
+        return self.a_mat
+
+    def get_b_mat(self, x: Array) -> Array:
+        return self.b_mat
+
+    def forward(self, h: Array, x: Array) -> Array:
+        """Perform a forward pass.
+
+        Args:
+            h: Hidden state of shape (H,).
+            x: Input of shape (H,).
+
+        Returns:
+            Hidden state of shape (H,).
+        """
+        a_mat = self.get_a_mat(x)
+        b_mat = self.get_b_mat(x)
+        h = a_mat @ h + b_mat.T @ x
+        return h
+
+    def forward_sequence(self, x_seq: Array) -> Array:
+        """Perform a forward pass across time.
+
+        Args:
+            x_seq: Input sequence of shape (T, H).
+
+        Returns:
+            Hidden state sequence of shape (T, H).
+        """
+
+        def step(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.forward(h, x)
+            return h, h
+
+        a_mat = self.get_a_mat(x_seq)
+        h_0 = jnp.zeros(a_mat.shape[0])
+        _, h_seq = jax.lax.scan(step, h_0, x_seq)
+        return h_seq
+
+
+class DiagSSMBlock(BaseSSMBlock):
+    a_diag: Array
+    b_mat: Array
+
+    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
+        keys = jax.random.split(key, 2)
+        self.a_diag = glorot(keys[0], (hidden_size,))
+        self.b_mat = glorot(keys[1], (hidden_size, hidden_size))
+
+    def get_a_mat(self, x: Array) -> Array:
+        return self.a_diag
+
+    def get_b_mat(self, x: Array) -> Array:
+        return self.b_mat
+
+    def forward(self, h: Array, x: Array) -> Array:
+        """Perform a forward pass.
+
+        Args:
+            h: Hidden state of shape (H,).
+            x: Input of shape (H,).
+
+        Returns:
+            Hidden state of shape (H,).
+        """
+        a_diag = self.get_a_mat(x)
+        b_mat = self.get_b_mat(x)
+        h = a_diag * h + b_mat.T @ x
+        return h
+
+    def forward_sequence(self, x_seq: Array, *, use_conv: bool = True, recursive_kernel_calc: bool = False) -> Array:
+        """Perform a potentially parallelized forward pass across time.
+
+        Args:
+            x_seq: Input sequence of shape (T, H).
+            use_conv: Whether to use convolution to compute the sequence.
+            recursive_kernel_calc: Whether to use a recursive kernel calculation.
+
+        Returns:
+            Hidden state sequence of shape (T, H).
+        """
+        if use_conv:
+            return self._forward_sequence_conv(x_seq, recursive_kernel_calc=recursive_kernel_calc)
+        else:
+            return self._forward_sequence_scan(x_seq)
+
+    def _get_kernel(self, x_seq: Array, length: int) -> Array:
+        """Returns the kernel with time as the final dimension."""
+        exponents = jnp.arange(length)
+        a_diag = self.get_a_mat(x_seq)
+        kernel = jnp.power(a_diag[:, None], exponents)  # (H, T)
+        kernel = kernel[:, None, :]  # (H, 1, T)
+        return kernel
+
+    def _get_kernel_recursive(self, x_seq: Array, length: int) -> Array:
+        """Returns the kernel with time as the final dimension."""
+        assert length % 2 == 0, "Length must be even."
+        a_diag = self.get_a_mat(x_seq)
+
+        def helper(length: int) -> tuple[Array, Array]:
+            """Returns the kernel and the sqrt of the diagonal."""
+            if length == 1:
+                return jnp.ones_like(a_diag)[:, None], a_diag[:, None]
+
+            half_length = length // 2
+            kernel_half, a_half = helper(half_length)
+            kernel = jnp.concatenate([kernel_half, a_half * kernel_half], axis=-1)
+            return kernel, a_half * a_half
+
+        kernel, a_diag = helper(length)
+        return kernel[:, None, :]  # (H, 1, L)
+
+    def _forward_sequence_conv(self, x_seq: Array, *, recursive_kernel_calc: bool = False) -> Array:
+        """Convolves x (T, H) across time using the kernel."""
+        seq_len, hidden_size = x_seq.shape
+        b_mat = self.get_b_mat(x_seq)
+
+        s = b_mat.T @ x_seq.T  # (H, T)
+        s_padded = jnp.pad(s, ((0, 0), (seq_len - 1, 0)))[None, :, :]  # (1, H, 2T-1)
+
+        if recursive_kernel_calc:
+            kernel = self._get_kernel_recursive(x_seq, seq_len)
+        else:
+            kernel = self._get_kernel(x_seq, seq_len)
+
+        kernel_flipped = jnp.flip(kernel, axis=-1)  # (H, 1, L)
+
+        conv_out = jax.lax.conv_general_dilated(
+            s_padded,
+            kernel_flipped,
+            window_strides=(1,),
+            padding="VALID",
+            dimension_numbers=("NCT", "OIT", "NCT"),  # convolving over time
+            feature_group_count=hidden_size,
+        )
+        conv_out = conv_out[0].T  # (T, H)
+        return conv_out
+
+    def _forward_sequence_scan(self, x_seq: Array) -> Array:
+        """Naively forward across time."""
+
+        def step(h: Array, x: Array) -> tuple[Array, Array]:
+            h = self.forward(h, x)
+            return h, h
+
+        a_diag = self.get_a_mat(x_seq)
+        h_0 = jnp.zeros(a_diag.shape[0])
+        _, h_seq = jax.lax.scan(step, h_0, x_seq)
+        return h_seq
+
+
+class DiscreteDiagSSMBlock(DiagSSMBlock):
+    delta: Array
+
+    def __init__(
+        self,
+        hidden_size: int,
+        *,
+        key: PRNGKeyArray,
+        init_delta: float = 1.0,
+        init_scale: float = 10.0,
+    ) -> None:
+        super().__init__(hidden_size, key=key)
+        self.delta = jnp.array(init_delta)
+
+        # A positive scale helps reduce the gradient at the start.
+        self.a_diag = jax.random.uniform(key, (hidden_size,), minval=-1.0, maxval=0.0) * init_scale
+
+    def get_a_mat(self, x: Array) -> Array:
+        """Discretize the diagonal matrix using zero-order hold."""
+        a_diag_discrete = jnp.exp(self.a_diag * self.delta)
+        return a_diag_discrete
+
+    def get_b_mat(self, x: Array) -> Array:
+        """Discretize the input matrix using zero-order hold."""
+        delta_a_diag = self.a_diag * self.delta
+        exp_a_diag = jnp.exp(delta_a_diag)
+        delta_a_inv = 1 / delta_a_diag
+        delta_b_mat = self.delta * self.b_mat
+
+        b_discrete = delta_a_inv * (exp_a_diag - 1) * delta_b_mat
+        return b_discrete
+
+
+class SSM(eqx.Module):
+    vocab_embedding: eqx.nn.Embedding
+    output_layer: eqx.nn.Linear
+    blocks: list[BaseSSMBlock]
+    num_layers: int = eqx.static_field()
+    hidden_size: int = eqx.static_field()
+    skip_connections: bool = eqx.static_field()
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        num_layers: int,
+        block_type: Literal["diagonal", "full_rank"] = "full_rank",
+        skip_connections: bool = False,
+        discretize: bool = False,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        vocab_key, s4_key = jax.random.split(key, 2)
+        self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
+        self.output_layer = eqx.nn.Linear(hidden_size, output_size, key=key)
+
+        block_keys = jax.random.split(s4_key, num_layers)
+
+        def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
+            match block_type:
+                case "diagonal":
+                    return (
+                        DiscreteDiagSSMBlock(hidden_size, key=key, init_delta=0.1)
+                        if discretize
+                        else DiagSSMBlock(hidden_size, key=key)
+                    )
+                case "full_rank":
+                    if discretize:
+                        raise ValueError("Full rank blocks do not support discretization due to instability.")
+                    return SSMBlock(hidden_size, key=key)
+                case _:
+                    raise ValueError(f"Unknown block type: {block_type}")
+
+        self.blocks = [get_block(block_keys[i]) for i in range(num_layers)]
+        self.skip_connections = skip_connections
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+
+    def __call__(self, hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+        new_hs = []
+        for i, block in enumerate(self.blocks):
+            h = block.forward(hs[i], x)
+            new_hs.append(h)
+            xh = jax.nn.gelu(h)
+            x = xh + x if self.skip_connections else xh
+        y = self.output_layer(x)
+        return new_hs, y
+
+    def _embed_input(self, x: Array) -> Array:
+        """U is the input to the S4 cell."""
+        return self.vocab_embedding(x)
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        x_emb = jax.vmap(self._embed_input)(x_seq)
+        for block in self.blocks:
+            h = block.forward_sequence(x_emb)
+            # h = block.naive_forward_sequence(x_emb)
+            h = jax.nn.gelu(h)
+            x_emb = h + x_emb if self.skip_connections else h
+        y = jax.vmap(self.output_layer)(x_emb)
+        return y
+
+    def generate_sequence(self, prompt_seq: Array, max_len: int) -> Array:
+        hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
+        prompt_seq_embedded = jax.vmap(self._embed_input)(prompt_seq)
+
+        def encode_step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
+            hs, y = self(hs, x)
+            return hs, y
+
+        def decode_step(
+            carry: tuple[list[Array], Array, PRNGKeyArray],
+            _: None,
+        ) -> tuple[tuple[list[Array], Array, PRNGKeyArray], Array]:
+            hs, last_token, rng = carry
+            token_embedded = self._embed_input(last_token)
+            hs, y = self(hs, token_embedded)
+            token = jax.random.categorical(rng, y)
+            rng = jax.random.split(rng)[0]
+            return (hs, token, rng), token
+
+        hs, _ = jax.lax.scan(encode_step, hs, prompt_seq_embedded)
+        _, sequence = jax.lax.scan(decode_step, (hs, prompt_seq[-1], jax.random.PRNGKey(0)), None, length=max_len)
+
+        return sequence
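
Two notes on the new module. First, DiscreteDiagSSMBlock applies the standard zero-order-hold discretization to the diagonal continuous-time system: with step size \Delta, it computes \bar{a} = e^{\Delta a} elementwise and \bar{B} = (\Delta a)^{-1}\,(e^{\Delta a} - 1)\,\Delta B, which is exactly what get_a_mat and get_b_mat return. Second, a minimal usage sketch of the new SSM class (the hyperparameter values are illustrative, not package defaults):

import jax
import jax.numpy as jnp
from xax.nn.ssm import SSM

key = jax.random.PRNGKey(0)
model = SSM(
    input_size=128,        # vocabulary size for the embedding
    hidden_size=64,
    output_size=128,
    num_layers=2,
    block_type="diagonal",
    discretize=True,       # selects DiscreteDiagSSMBlock internally
    key=key,
)

tokens = jnp.array([1, 2, 3, 4])         # (T,) integer token ids
logits = model.predict_sequence(tokens)  # (T, output_size)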
{xax-0.1.11 → xax-0.1.12}/xax/task/mixins/train.py
@@ -218,26 +218,32 @@ class TrainMixin(
         state = super().on_step_end(state)
         return state.replace(elapsed_time_s=time.time() - state.start_time_s)
 
-    def log_train_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
+    def log_train_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         """Override this function to do logging during the training phase.
 
         This function is called after the model forward pass and before the
         backward pass. It is called in the training phase.
 
         Args:
+            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
             metrics: The metrics for the current batch.
             state: The current training state.
         """
 
-    def log_valid_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
+    def log_valid_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         """Override this function to do logging during the validation phase.
 
         This function is called after the model forward pass. It is called in
         the validation phase.
 
         Args:
+            model: The current model.
             batch: The batch from the dataloader.
             output: The model output.
             metrics: The metrics for the current batch.
@@ -251,7 +257,9 @@ class TrainMixin(
         for k, v in d.items():
             self.logger.log_scalar(k, v, namespace=ns)
 
-    def log_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
+    def log_step(
+        self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
+    ) -> None:
         phase = state.phase
 
         for k, v in metrics.items():
@@ -265,9 +273,9 @@ class TrainMixin(
         # Delegate to the appropriate logging function based on the phase.
         match phase:
             case "train":
-                self.log_train_step(batch, output, metrics, state)
+                self.log_train_step(model, batch, output, metrics, state)
             case "valid":
-                self.log_valid_step(batch, output, metrics, state)
+                self.log_valid_step(model, batch, output, metrics, state)
             case _:
                 raise KeyError(f"Unknown phase: {phase}")
 
@@ -579,7 +587,7 @@ class TrainMixin(
         )
 
         output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
-        self.log_step(valid_batch, output, metrics, state)
+        self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
 
         state = self.on_step_start(state)
         train_batch = next(train_pf)
@@ -597,7 +605,7 @@ class TrainMixin(
             batch=train_batch,
             state=state,
         )
-        self.log_step(train_batch, output, metrics, state)
+        self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
 
         state = self.on_step_end(state)
 
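This is a breaking change for subclasses: log_train_step, log_valid_step, and log_step now receive the combined model, eqx.combine(model_arr, model_static), as their first positional argument. A hypothetical override showing the updated signature (MyTask and the parameter-count logging are illustrative only, not from the package):

import equinox as eqx
import jax

class MyTask:  # stands in for a task class that mixes in TrainMixin
    def log_train_step(self, model, batch, output, metrics, state) -> None:
        # The model is now available for inspection at logging time,
        # e.g. to count or summarize its parameters.
        arrays = [x for x in jax.tree_util.tree_leaves(model) if eqx.is_array(x)]
        n_params = sum(x.size for x in arrays)
        print(f"parameters: {n_params}")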
{xax-0.1.11 → xax-0.1.12/xax.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.1.11
+Version: 0.1.12
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
xax-0.1.11/xax/nn/ssm.py DELETED
@@ -1,296 +0,0 @@
-"""State space models."""
-
-from abc import ABC, abstractmethod
-from typing import Literal
-
-import equinox as eqx
-import jax
-import jax.numpy as jnp
-from jaxtyping import Array, PRNGKeyArray
-
-
-def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
-    return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))
-
-
-class DiscreteTimeS4(eqx.Module):
-    a: Array
-    B: Array
-    C: Array
-    proj_in: eqx.nn.Linear
-    proj_out: eqx.nn.Linear
-
-    def __init__(
-        self,
-        hidden_size: int,
-        projection_size: int,
-        input_size: int,
-        output_size: int,
-        *,
-        key: PRNGKeyArray,
-    ) -> None:
-        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
-        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
-        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
-        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
-        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
-
-    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
-        h = self.a * h + self.B.T @ x
-        y = self.C.T @ h
-        return h, y
-
-    def predict_sequence(self, x_seq: Array) -> Array:
-        x_proj = jax.vmap(lambda x: jax.nn.relu(self.proj_in(x)))(x_seq)
-        h = jnp.zeros(self.a.shape[0])
-
-        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
-            h = self.a * h + self.B.T @ x
-            y = self.C.T @ h
-            return h, y
-
-        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
-        y_out = jax.vmap(self.proj_out)(y_seq)
-        return y_out
-
-
-class S4Layer(eqx.Module):
-    a: Array
-    B: Array
-    C: Array
-    proj_in: eqx.nn.Linear
-    proj_out: eqx.nn.Linear
-    delta: Array
-
-    def __init__(
-        self,
-        hidden_size: int,
-        projection_size: int,
-        input_size: int,
-        output_size: int,
-        *,
-        key: PRNGKeyArray,
-    ) -> None:
-        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
-        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
-        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
-        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
-        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
-        self.delta = jax.random.uniform(key, (hidden_size,))
-
-    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
-        delta_a = self.delta * self.a
-        a_bar = jnp.exp(delta_a)
-        b_bar = jnp.linalg.inv(delta_a) * (a_bar - 1) @ (self.delta * self.B)
-        h = a_bar * h + b_bar.T @ x
-        y = self.C.T @ h
-        return h, y
-
-    def predict_sequence(self, x_seq: Array) -> Array:
-        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
-        h = jnp.zeros(self.a.shape[0])
-
-        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
-            h = self.a * h + self.B.T @ x
-            y = self.C.T @ h
-            return h, y
-
-        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
-        y_out = jax.vmap(self.proj_out)(y_seq)
-        return y_out
-
-
-class S6Layer(eqx.Module):
-    a: Array
-    B: Array
-    C: Array
-    proj_in: eqx.nn.Linear
-    proj_out: eqx.nn.Linear
-    delta: Array
-
-    def __init__(
-        self,
-        hidden_size: int,
-        projection_size: int,
-        input_size: int,
-        output_size: int,
-        *,
-        key: PRNGKeyArray,
-    ) -> None:
-        self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
-        self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
-        self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
-        self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
-        self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
-        self.delta = jax.random.uniform(key, (hidden_size,))
-
-    def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
-        h = self.a * h + self.B.T @ x
-        y = self.C.T @ h
-        return h, y
-
-    def predict_sequence(self, x_seq: Array) -> Array:
-        x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
-        h = jnp.zeros(self.a.shape[0])
-
-        def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
-            h = self.a * h + self.B.T @ x
-            y = self.C.T @ h
-            return h, y
-
-        _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
-        y_out = jax.vmap(self.proj_out)(y_seq)
-        return y_out
-
-
-class BaseSSMBlock(eqx.Module, ABC):
-    @abstractmethod
-    def forward(self, h: Array, x: Array) -> Array:
-        pass
-
-
-class SSMBlock(BaseSSMBlock):
-    a_mat: Array
-    b_mat: Array
-
-    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
-        key_a, key_b = jax.random.split(key)
-        self.a_mat = glorot(key_a, (hidden_size, hidden_size))
-        self.b_mat = glorot(key_b, (hidden_size, hidden_size))
-
-    def forward(self, h: Array, x: Array) -> Array:
-        h = self.a_mat @ h + self.b_mat.T @ x
-        return h
-
-    def get_kernel(self, length: int) -> Array:
-        return self.a_mat
-
-
-class DiagSSMBlock(BaseSSMBlock):
-    a_mat: Array
-    b_mat: Array
-
-    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
-        keys = jax.random.split(key, 2)
-        self.a_mat = glorot(keys[0], (hidden_size,))
-        self.b_mat = glorot(keys[1], (hidden_size, hidden_size))
-
-    def forward(self, h: Array, x: Array) -> Array:
-        h = self.a_mat * h + self.b_mat.T @ x
-        h = jax.nn.tanh(h)
-        return h
-
-    def get_kernel(self, length: int) -> Array:
-        """Returns the kernel with time as the final dimension."""
-        exponents = jnp.arange(length)
-        kernel = jnp.power(self.a_mat[:, None], exponents)  # (H, L)
-        kernel = kernel[:, None, :]  # (H, 1, L)
-        return kernel
-
-    def forward_across_time(self, x: Array) -> Array:
-        """Convolves x (T, H) across time using the kernel."""
-        tsz, nhid = x.shape
-
-        # Compute s = x @ U.T + b, with shape (N, T, H)
-        s = self.b_mat.T @ x
-        s = s.T  # (H, T)
-
-        kernel = self.get_kernel(tsz)  # (H, 1, T)
-        kernel_flipped = jnp.flip(kernel, axis=-1)
-
-        # Pad s on the left along the time axis (pad length T-1)
-        s_padded = jnp.pad(s, ((0, 0), (0, 0), (tsz - 1, 0)))
-
-        # Perform depthwise (grouped) 1D convolution.
-        # We use input shape (N, H, L) and kernel shape (H, 1, T) with feature_group_count=H.
-        # The dimension_numbers are chosen so that the channel dimension is second.
-        conv_out = jax.lax.conv_general_dilated(
-            s_padded,
-            kernel_flipped,
-            window_strides=(1,),
-            padding="VALID",
-            dimension_numbers=("NCH", "OIH", "NCH"),
-            feature_group_count=nhid,
-        )
-        # conv_out has shape (N, H, T); transpose to (N, T, H)
-        conv_out = jnp.transpose(conv_out, (0, 2, 1))
-        return conv_out
-
-    def naive_forward_accross_time(self, x: Array) -> Array:
-        """Naively forward across time."""
-
-        def step(h: Array, x: Array) -> tuple[Array, Array]:
-            h = self.forward(h, x)
-            return h, h
-
-        h_0 = jnp.zeros(self.a_mat.shape[0])
-        _, h_seq = jax.lax.scan(step, h_0, x)
-        return h_seq
-
-
-class S4(eqx.Module):
-    vocab_embedding: eqx.nn.Embedding
-    proj_in: eqx.nn.Linear
-    proj_out: eqx.nn.Linear
-    blocks: list[BaseSSMBlock]
-    num_layers: int = eqx.static_field()
-    hidden_size: int = eqx.static_field()
-    skip_connections: bool = eqx.static_field()
-
-    def __init__(
-        self,
-        input_size: int,
-        hidden_size: int,
-        output_size: int,
-        num_layers: int,
-        block_type: Literal["ssm", "diag"] = "ssm",
-        skip_connections: bool = False,
-        *,
-        key: PRNGKeyArray,
-    ) -> None:
-        vocab_key, s4_key = jax.random.split(key, 2)
-        self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
-        self.proj_in = eqx.nn.Linear(hidden_size, hidden_size, key=key)
-        self.proj_out = eqx.nn.Linear(hidden_size, output_size, key=key)
-
-        block_keys = jax.random.split(s4_key, num_layers)
-
-        def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
-            match block_type:
-                case "ssm":
-                    return SSMBlock(hidden_size, key=key)
-                case "diag":
-                    return DiagSSMBlock(hidden_size, key=key)
-                case _:
-                    raise ValueError(f"Unknown block type: {block_type}")
-
-        self.blocks = [get_block(block_keys[i]) for i in range(num_layers)]
-        self.skip_connections = skip_connections
-        self.num_layers = num_layers
-        self.hidden_size = hidden_size
-
-    def __call__(self, hs: list[Array], x: Array) -> tuple[list[Array], Array]:
-        new_hs = []
-        for i, block in enumerate(self.blocks):
-            h = block.forward(hs[i], x)
-            new_hs.append(h)
-            xh = jax.nn.gelu(h)
-            x = xh + x if self.skip_connections else xh
-        y = self.proj_out(x)
-        return new_hs, y
-
-    def _embed_input(self, x: Array) -> Array:
-        """U is the input to the S4 cell."""
-        embedded = self.vocab_embedding(x)
-        return jax.nn.gelu(self.proj_in(embedded))
-
-    def predict_sequence(self, x_seq: Array) -> Array:
-        x_emb = jax.vmap(self._embed_input)(x_seq)
-        hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
-
-        def step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
-            hs, y = self(hs, x)
-            return hs, y
-
-        _, y_seq = jax.lax.scan(step, hs, x_emb)
-        return y_seq
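
For reference, the convolutional path in both the removed file (forward_across_time) and its replacement (_forward_sequence_conv) computes the linear diagonal recurrence in parallel via the same identity: unrolling h_t = a \odot h_{t-1} + B^\top x_t from h_0 = 0 gives

    h_t = \sum_{k=1}^{t} a^{t-k} \odot (B^\top x_k),

which is a causal depthwise convolution of the sequence s_t = B^\top x_t with the per-channel kernel (1, a, a^2, \ldots). That is why the code left-pads s by T - 1 along time and convolves with the flipped kernel using feature_group_count=H. The new _get_kernel_recursive builds the same powers of a by repeated doubling rather than by direct exponentiation.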