xax 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.11"
15
+ __version__ = "0.1.12"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -43,15 +43,14 @@ __all__ = [
43
43
  "euler_to_quat",
44
44
  "get_projected_gravity_vector_from_quat",
45
45
  "quat_to_euler",
46
+ "rotate_vector_by_quat",
46
47
  "cross_entropy",
47
48
  "cast_norm_type",
48
49
  "get_norm",
49
50
  "is_master",
51
+ "BaseSSMBlock",
50
52
  "DiagSSMBlock",
51
- "DiscreteTimeS4",
52
- "S4",
53
- "S4Layer",
54
- "S6Layer",
53
+ "SSM",
55
54
  "SSMBlock",
56
55
  "BaseLauncher",
57
56
  "CliLauncher",
@@ -203,15 +202,14 @@ NAME_MAP: dict[str, str] = {
203
202
  "euler_to_quat": "nn.geom",
204
203
  "get_projected_gravity_vector_from_quat": "nn.geom",
205
204
  "quat_to_euler": "nn.geom",
205
+ "rotate_vector_by_quat": "nn.geom",
206
206
  "cross_entropy": "nn.losses",
207
207
  "cast_norm_type": "nn.norm",
208
208
  "get_norm": "nn.norm",
209
209
  "is_master": "nn.parallel",
210
+ "BaseSSMBlock": "nn.ssm",
210
211
  "DiagSSMBlock": "nn.ssm",
211
- "DiscreteTimeS4": "nn.ssm",
212
- "S4": "nn.ssm",
213
- "S4Layer": "nn.ssm",
214
- "S6Layer": "nn.ssm",
212
+ "SSM": "nn.ssm",
215
213
  "SSMBlock": "nn.ssm",
216
214
  "BaseLauncher": "task.launchers.base",
217
215
  "CliLauncher": "task.launchers.cli",
@@ -364,11 +362,12 @@ if IMPORT_ALL or TYPE_CHECKING:
364
362
  euler_to_quat,
365
363
  get_projected_gravity_vector_from_quat,
366
364
  quat_to_euler,
365
+ rotate_vector_by_quat,
367
366
  )
368
367
  from xax.nn.losses import cross_entropy
369
368
  from xax.nn.norm import NormType, cast_norm_type, get_norm
370
369
  from xax.nn.parallel import is_master
371
- from xax.nn.ssm import S4, DiagSSMBlock, DiscreteTimeS4, S4Layer, S6Layer, SSMBlock
370
+ from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
372
371
  from xax.task.base import RawConfigType
373
372
  from xax.task.launchers.base import BaseLauncher
374
373
  from xax.task.launchers.cli import CliLauncher
xax/nn/geom.py CHANGED
@@ -99,3 +99,60 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -
99
99
 
100
100
  # Note: We're rotating [0,0,-1], so we negate gz to match the expected direction
101
101
  return jnp.concatenate([gx, gy, -gz], axis=-1)
102
+
103
+
104
+ def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) -> jax.Array:
105
+ """Rotates a vector by a quaternion.
106
+
107
+ Args:
108
+ vector: The vector to rotate, shape (*, 3).
109
+ quat: The quaternion to rotate by, shape (*, 4).
110
+ eps: A small epsilon value to avoid division by zero.
111
+
112
+ Returns:
113
+ The rotated vector, shape (*, 3).
114
+ """
115
+ # Normalize quaternion
116
+ quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
117
+ w, x, y, z = jnp.split(quat, 4, axis=-1)
118
+
119
+ # Extract vector components
120
+ vx, vy, vz = jnp.split(vector, 3, axis=-1)
121
+
122
+ # Terms for x component
123
+ xx = (
124
+ w * w * vx
125
+ + 2 * y * w * vz
126
+ - 2 * z * w * vy
127
+ + x * x * vx
128
+ + 2 * y * x * vy
129
+ + 2 * z * x * vz
130
+ - z * z * vx
131
+ - y * y * vx
132
+ )
133
+
134
+ # Terms for y component
135
+ yy = (
136
+ 2 * x * y * vx
137
+ + y * y * vy
138
+ + 2 * z * y * vz
139
+ + 2 * w * z * vx
140
+ - z * z * vy
141
+ + w * w * vy
142
+ - 2 * w * x * vz
143
+ - x * x * vy
144
+ )
145
+
146
+ # Terms for z component
147
+ zz = (
148
+ 2 * x * z * vx
149
+ + 2 * y * z * vy
150
+ + z * z * vz
151
+ - 2 * w * y * vx
152
+ + w * w * vz
153
+ + 2 * w * x * vy
154
+ - y * y * vz
155
+ - x * x * vz
156
+ )
157
+
158
+ return jnp.concatenate([xx, yy, zz], axis=-1)
xax/nn/ssm.py CHANGED
@@ -13,140 +13,18 @@ def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
13
13
  return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))
14
14
 
15
15
 
16
- class DiscreteTimeS4(eqx.Module):
17
- a: Array
18
- B: Array
19
- C: Array
20
- proj_in: eqx.nn.Linear
21
- proj_out: eqx.nn.Linear
22
-
23
- def __init__(
24
- self,
25
- hidden_size: int,
26
- projection_size: int,
27
- input_size: int,
28
- output_size: int,
29
- *,
30
- key: PRNGKeyArray,
31
- ) -> None:
32
- self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
33
- self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
34
- self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
35
- self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
36
- self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
37
-
38
- def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
39
- h = self.a * h + self.B.T @ x
40
- y = self.C.T @ h
41
- return h, y
42
-
43
- def predict_sequence(self, x_seq: Array) -> Array:
44
- x_proj = jax.vmap(lambda x: jax.nn.relu(self.proj_in(x)))(x_seq)
45
- h = jnp.zeros(self.a.shape[0])
46
-
47
- def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
48
- h = self.a * h + self.B.T @ x
49
- y = self.C.T @ h
50
- return h, y
51
-
52
- _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
53
- y_out = jax.vmap(self.proj_out)(y_seq)
54
- return y_out
55
-
56
-
57
- class S4Layer(eqx.Module):
58
- a: Array
59
- B: Array
60
- C: Array
61
- proj_in: eqx.nn.Linear
62
- proj_out: eqx.nn.Linear
63
- delta: Array
64
-
65
- def __init__(
66
- self,
67
- hidden_size: int,
68
- projection_size: int,
69
- input_size: int,
70
- output_size: int,
71
- *,
72
- key: PRNGKeyArray,
73
- ) -> None:
74
- self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
75
- self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
76
- self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
77
- self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
78
- self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
79
- self.delta = jax.random.uniform(key, (hidden_size,))
80
-
81
- def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
82
- delta_a = self.delta * self.a
83
- a_bar = jnp.exp(delta_a)
84
- b_bar = jnp.linalg.inv(delta_a) * (a_bar - 1) @ (self.delta * self.B)
85
- h = a_bar * h + b_bar.T @ x
86
- y = self.C.T @ h
87
- return h, y
88
-
89
- def predict_sequence(self, x_seq: Array) -> Array:
90
- x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
91
- h = jnp.zeros(self.a.shape[0])
92
-
93
- def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
94
- h = self.a * h + self.B.T @ x
95
- y = self.C.T @ h
96
- return h, y
97
-
98
- _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
99
- y_out = jax.vmap(self.proj_out)(y_seq)
100
- return y_out
101
-
102
-
103
- class S6Layer(eqx.Module):
104
- a: Array
105
- B: Array
106
- C: Array
107
- proj_in: eqx.nn.Linear
108
- proj_out: eqx.nn.Linear
109
- delta: Array
110
-
111
- def __init__(
112
- self,
113
- hidden_size: int,
114
- projection_size: int,
115
- input_size: int,
116
- output_size: int,
117
- *,
118
- key: PRNGKeyArray,
119
- ) -> None:
120
- self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
121
- self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
122
- self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
123
- self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
124
- self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
125
- self.delta = jax.random.uniform(key, (hidden_size,))
126
-
127
- def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
128
- h = self.a * h + self.B.T @ x
129
- y = self.C.T @ h
130
- return h, y
131
-
132
- def predict_sequence(self, x_seq: Array) -> Array:
133
- x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
134
- h = jnp.zeros(self.a.shape[0])
135
-
136
- def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
137
- h = self.a * h + self.B.T @ x
138
- y = self.C.T @ h
139
- return h, y
16
+ class BaseSSMBlock(eqx.Module, ABC):
17
+ @abstractmethod
18
+ def forward(self, h: Array, x: Array) -> Array: ...
140
19
 
141
- _, y_seq = jax.lax.scan(scan_fn, h, x_proj)
142
- y_out = jax.vmap(self.proj_out)(y_seq)
143
- return y_out
20
+ @abstractmethod
21
+ def forward_sequence(self, x_seq: Array) -> Array: ...
144
22
 
23
+ @abstractmethod
24
+ def get_a_mat(self, x: Array) -> Array: ...
145
25
 
146
- class BaseSSMBlock(eqx.Module, ABC):
147
26
  @abstractmethod
148
- def forward(self, h: Array, x: Array) -> Array:
149
- pass
27
+ def get_b_mat(self, x: Array) -> Array: ...
150
28
 
151
29
 
152
30
  class SSMBlock(BaseSSMBlock):
@@ -158,80 +36,194 @@ class SSMBlock(BaseSSMBlock):
158
36
  self.a_mat = glorot(key_a, (hidden_size, hidden_size))
159
37
  self.b_mat = glorot(key_b, (hidden_size, hidden_size))
160
38
 
39
+ def get_a_mat(self, x: Array) -> Array:
40
+ return self.a_mat
41
+
42
+ def get_b_mat(self, x: Array) -> Array:
43
+ return self.b_mat
44
+
161
45
  def forward(self, h: Array, x: Array) -> Array:
162
- h = self.a_mat @ h + self.b_mat.T @ x
46
+ """Perform a forward pass.
47
+
48
+ Args:
49
+ h: Hidden state of shape (H,).
50
+ x: Input of shape (H,).
51
+
52
+ Returns:
53
+ Hidden state of shape (H,).
54
+ """
55
+ a_mat = self.get_a_mat(x)
56
+ b_mat = self.get_b_mat(x)
57
+ h = a_mat @ h + b_mat.T @ x
163
58
  return h
164
59
 
165
- def get_kernel(self, length: int) -> Array:
166
- return self.a_mat
60
+ def forward_sequence(self, x_seq: Array) -> Array:
61
+ """Perform a forward pass across time.
62
+
63
+ Args:
64
+ x_seq: Input sequence of shape (T, H).
65
+
66
+ Returns:
67
+ Hidden state sequence of shape (T, H).
68
+ """
69
+
70
+ def step(h: Array, x: Array) -> tuple[Array, Array]:
71
+ h = self.forward(h, x)
72
+ return h, h
73
+
74
+ a_mat = self.get_a_mat(x_seq)
75
+ h_0 = jnp.zeros(a_mat.shape[0])
76
+ _, h_seq = jax.lax.scan(step, h_0, x_seq)
77
+ return h_seq
167
78
 
168
79
 
169
80
  class DiagSSMBlock(BaseSSMBlock):
170
- a_mat: Array
81
+ a_diag: Array
171
82
  b_mat: Array
172
83
 
173
84
  def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
174
85
  keys = jax.random.split(key, 2)
175
- self.a_mat = glorot(keys[0], (hidden_size,))
86
+ self.a_diag = glorot(keys[0], (hidden_size,))
176
87
  self.b_mat = glorot(keys[1], (hidden_size, hidden_size))
177
88
 
89
+ def get_a_mat(self, x: Array) -> Array:
90
+ return self.a_diag
91
+
92
+ def get_b_mat(self, x: Array) -> Array:
93
+ return self.b_mat
94
+
178
95
  def forward(self, h: Array, x: Array) -> Array:
179
- h = self.a_mat * h + self.b_mat.T @ x
180
- h = jax.nn.tanh(h)
96
+ """Perform a forward pass.
97
+
98
+ Args:
99
+ h: Hidden state of shape (H,).
100
+ x: Input of shape (H,).
101
+
102
+ Returns:
103
+ Hidden state of shape (H,).
104
+ """
105
+ a_diag = self.get_a_mat(x)
106
+ b_mat = self.get_b_mat(x)
107
+ h = a_diag * h + b_mat.T @ x
181
108
  return h
182
109
 
183
- def get_kernel(self, length: int) -> Array:
110
+ def forward_sequence(self, x_seq: Array, *, use_conv: bool = True, recursive_kernel_calc: bool = False) -> Array:
111
+ """Perform a potentially parallelized forward pass across time.
112
+
113
+ Args:
114
+ x_seq: Input sequence of shape (T, H).
115
+ use_conv: Whether to use convolution to compute the sequence.
116
+ recursive_kernel_calc: Whether to use a recursive kernel calculation.
117
+
118
+ Returns:
119
+ Hidden state sequence of shape (T, H).
120
+ """
121
+ if use_conv:
122
+ return self._forward_sequence_conv(x_seq, recursive_kernel_calc=recursive_kernel_calc)
123
+ else:
124
+ return self._forward_sequence_scan(x_seq)
125
+
126
+ def _get_kernel(self, x_seq: Array, length: int) -> Array:
184
127
  """Returns the kernel with time as the final dimension."""
185
128
  exponents = jnp.arange(length)
186
- kernel = jnp.power(self.a_mat[:, None], exponents) # (H, L)
187
- kernel = kernel[:, None, :] # (H, 1, L)
129
+ a_diag = self.get_a_mat(x_seq)
130
+ kernel = jnp.power(a_diag[:, None], exponents) # (H, T)
131
+ kernel = kernel[:, None, :] # (H, 1, T)
188
132
  return kernel
189
133
 
190
- def forward_across_time(self, x: Array) -> Array:
134
+ def _get_kernel_recursive(self, x_seq: Array, length: int) -> Array:
135
+ """Returns the kernel with time as the final dimension."""
136
+ assert length % 2 == 0, "Length must be even."
137
+ a_diag = self.get_a_mat(x_seq)
138
+
139
+ def helper(length: int) -> tuple[Array, Array]:
140
+ """Returns the kernel and the sqrt of the diagonal."""
141
+ if length == 1:
142
+ return jnp.ones_like(a_diag)[:, None], a_diag[:, None]
143
+
144
+ half_length = length // 2
145
+ kernel_half, a_half = helper(half_length)
146
+ kernel = jnp.concatenate([kernel_half, a_half * kernel_half], axis=-1)
147
+ return kernel, a_half * a_half
148
+
149
+ kernel, a_diag = helper(length)
150
+ return kernel[:, None, :] # (H, 1, L)
151
+
152
+ def _forward_sequence_conv(self, x_seq: Array, *, recursive_kernel_calc: bool = False) -> Array:
191
153
  """Convolves x (T, H) across time using the kernel."""
192
- tsz, nhid = x.shape
154
+ seq_len, hidden_size = x_seq.shape
155
+ b_mat = self.get_b_mat(x_seq)
193
156
 
194
- # Compute s = x @ U.T + b, with shape (N, T, H)
195
- s = self.b_mat.T @ x
196
- s = s.T # (H, T)
157
+ s = b_mat.T @ x_seq.T # (H, T)
158
+ s_padded = jnp.pad(s, ((0, 0), (seq_len - 1, 0)))[None, :, :] # (1, H, 2T-1)
197
159
 
198
- kernel = self.get_kernel(tsz) # (H, 1, T)
199
- kernel_flipped = jnp.flip(kernel, axis=-1)
160
+ if recursive_kernel_calc:
161
+ kernel = self._get_kernel_recursive(x_seq, seq_len)
162
+ else:
163
+ kernel = self._get_kernel(x_seq, seq_len)
200
164
 
201
- # Pad s on the left along the time axis (pad length T-1)
202
- s_padded = jnp.pad(s, ((0, 0), (0, 0), (tsz - 1, 0)))
165
+ kernel_flipped = jnp.flip(kernel, axis=-1) # (H, 1, L)
203
166
 
204
- # Perform depthwise (grouped) 1D convolution.
205
- # We use input shape (N, H, L) and kernel shape (H, 1, T) with feature_group_count=H.
206
- # The dimension_numbers are chosen so that the channel dimension is second.
207
167
  conv_out = jax.lax.conv_general_dilated(
208
168
  s_padded,
209
169
  kernel_flipped,
210
170
  window_strides=(1,),
211
171
  padding="VALID",
212
- dimension_numbers=("NCH", "OIH", "NCH"),
213
- feature_group_count=nhid,
172
+ dimension_numbers=("NCT", "OIT", "NCT"), # convolving over time
173
+ feature_group_count=hidden_size,
214
174
  )
215
- # conv_out has shape (N, H, T); transpose to (N, T, H)
216
- conv_out = jnp.transpose(conv_out, (0, 2, 1))
175
+ conv_out = conv_out[0].T # (T, H)
217
176
  return conv_out
218
177
 
219
- def naive_forward_accross_time(self, x: Array) -> Array:
178
+ def _forward_sequence_scan(self, x_seq: Array) -> Array:
220
179
  """Naively forward across time."""
221
180
 
222
181
  def step(h: Array, x: Array) -> tuple[Array, Array]:
223
182
  h = self.forward(h, x)
224
183
  return h, h
225
184
 
226
- h_0 = jnp.zeros(self.a_mat.shape[0])
227
- _, h_seq = jax.lax.scan(step, h_0, x)
185
+ a_diag = self.get_a_mat(x_seq)
186
+ h_0 = jnp.zeros(a_diag.shape[0])
187
+ _, h_seq = jax.lax.scan(step, h_0, x_seq)
228
188
  return h_seq
229
189
 
230
190
 
231
- class S4(eqx.Module):
191
+ class DiscreteDiagSSMBlock(DiagSSMBlock):
192
+ delta: Array
193
+
194
+ def __init__(
195
+ self,
196
+ hidden_size: int,
197
+ *,
198
+ key: PRNGKeyArray,
199
+ init_delta: float = 1.0,
200
+ init_scale: float = 10.0,
201
+ ) -> None:
202
+ super().__init__(hidden_size, key=key)
203
+ self.delta = jnp.array(init_delta)
204
+
205
+ # A positive scale helps reduce the gradient at the start.
206
+ self.a_diag = jax.random.uniform(key, (hidden_size,), minval=-1.0, maxval=0.0) * init_scale
207
+
208
+ def get_a_mat(self, x: Array) -> Array:
209
+ """Discretize the diagonal matrix using zero-order hold."""
210
+ a_diag_discrete = jnp.exp(self.a_diag * self.delta)
211
+ return a_diag_discrete
212
+
213
+ def get_b_mat(self, x: Array) -> Array:
214
+ """Discretize the input matrix using zero-order hold."""
215
+ delta_a_diag = self.a_diag * self.delta
216
+ exp_a_diag = jnp.exp(delta_a_diag)
217
+ delta_a_inv = 1 / delta_a_diag
218
+ delta_b_mat = self.delta * self.b_mat
219
+
220
+ b_discrete = delta_a_inv * (exp_a_diag - 1) * delta_b_mat
221
+ return b_discrete
222
+
223
+
224
+ class SSM(eqx.Module):
232
225
  vocab_embedding: eqx.nn.Embedding
233
- proj_in: eqx.nn.Linear
234
- proj_out: eqx.nn.Linear
226
+ output_layer: eqx.nn.Linear
235
227
  blocks: list[BaseSSMBlock]
236
228
  num_layers: int = eqx.static_field()
237
229
  hidden_size: int = eqx.static_field()
@@ -243,24 +235,30 @@ class S4(eqx.Module):
243
235
  hidden_size: int,
244
236
  output_size: int,
245
237
  num_layers: int,
246
- block_type: Literal["ssm", "diag"] = "ssm",
238
+ block_type: Literal["diagonal", "full_rank"] = "full_rank",
247
239
  skip_connections: bool = False,
240
+ discretize: bool = False,
248
241
  *,
249
242
  key: PRNGKeyArray,
250
243
  ) -> None:
251
244
  vocab_key, s4_key = jax.random.split(key, 2)
252
245
  self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
253
- self.proj_in = eqx.nn.Linear(hidden_size, hidden_size, key=key)
254
- self.proj_out = eqx.nn.Linear(hidden_size, output_size, key=key)
246
+ self.output_layer = eqx.nn.Linear(hidden_size, output_size, key=key)
255
247
 
256
248
  block_keys = jax.random.split(s4_key, num_layers)
257
249
 
258
250
  def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
259
251
  match block_type:
260
- case "ssm":
252
+ case "diagonal":
253
+ return (
254
+ DiscreteDiagSSMBlock(hidden_size, key=key, init_delta=0.1)
255
+ if discretize
256
+ else DiagSSMBlock(hidden_size, key=key)
257
+ )
258
+ case "full_rank":
259
+ if discretize:
260
+ raise ValueError("Full rank blocks do not support discretization due to instability.")
261
261
  return SSMBlock(hidden_size, key=key)
262
- case "diag":
263
- return DiagSSMBlock(hidden_size, key=key)
264
262
  case _:
265
263
  raise ValueError(f"Unknown block type: {block_type}")
266
264
 
@@ -276,21 +274,43 @@ class S4(eqx.Module):
276
274
  new_hs.append(h)
277
275
  xh = jax.nn.gelu(h)
278
276
  x = xh + x if self.skip_connections else xh
279
- y = self.proj_out(x)
277
+ y = self.output_layer(x)
280
278
  return new_hs, y
281
279
 
282
280
  def _embed_input(self, x: Array) -> Array:
283
281
  """U is the input to the S4 cell."""
284
- embedded = self.vocab_embedding(x)
285
- return jax.nn.gelu(self.proj_in(embedded))
282
+ return self.vocab_embedding(x)
286
283
 
287
284
  def predict_sequence(self, x_seq: Array) -> Array:
288
285
  x_emb = jax.vmap(self._embed_input)(x_seq)
286
+ for block in self.blocks:
287
+ h = block.forward_sequence(x_emb)
288
+ # h = block.naive_forward_sequence(x_emb)
289
+ h = jax.nn.gelu(h)
290
+ x_emb = h + x_emb if self.skip_connections else h
291
+ y = jax.vmap(self.output_layer)(x_emb)
292
+ return y
293
+
294
+ def generate_sequence(self, prompt_seq: Array, max_len: int) -> Array:
289
295
  hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
296
+ prompt_seq_embedded = jax.vmap(self._embed_input)(prompt_seq)
290
297
 
291
- def step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
298
+ def encode_step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
292
299
  hs, y = self(hs, x)
293
300
  return hs, y
294
301
 
295
- _, y_seq = jax.lax.scan(step, hs, x_emb)
296
- return y_seq
302
+ def decode_step(
303
+ carry: tuple[list[Array], Array, PRNGKeyArray],
304
+ _: None,
305
+ ) -> tuple[tuple[list[Array], Array, PRNGKeyArray], Array]:
306
+ hs, last_token, rng = carry
307
+ token_embedded = self._embed_input(last_token)
308
+ hs, y = self(hs, token_embedded)
309
+ token = jax.random.categorical(rng, y)
310
+ rng = jax.random.split(rng)[0]
311
+ return (hs, token, rng), token
312
+
313
+ hs, _ = jax.lax.scan(encode_step, hs, prompt_seq_embedded)
314
+ _, sequence = jax.lax.scan(decode_step, (hs, prompt_seq[-1], jax.random.PRNGKey(0)), None, length=max_len)
315
+
316
+ return sequence
xax/task/mixins/train.py CHANGED
@@ -218,26 +218,32 @@ class TrainMixin(
218
218
  state = super().on_step_end(state)
219
219
  return state.replace(elapsed_time_s=time.time() - state.start_time_s)
220
220
 
221
- def log_train_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
221
+ def log_train_step(
222
+ self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
223
+ ) -> None:
222
224
  """Override this function to do logging during the training phase.
223
225
 
224
226
  This function is called after the model forward pass and before the
225
227
  backward pass. It is called in the training phase.
226
228
 
227
229
  Args:
230
+ model: The current model.
228
231
  batch: The batch from the dataloader.
229
232
  output: The model output.
230
233
  metrics: The metrics for the current batch.
231
234
  state: The current training state.
232
235
  """
233
236
 
234
- def log_valid_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
237
+ def log_valid_step(
238
+ self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
239
+ ) -> None:
235
240
  """Override this function to do logging during the validation phase.
236
241
 
237
242
  This function is called after the model forward pass. It is called in
238
243
  the validation phase.
239
244
 
240
245
  Args:
246
+ model: The current model.
241
247
  batch: The batch from the dataloader.
242
248
  output: The model output.
243
249
  metrics: The metrics for the current batch.
@@ -251,7 +257,9 @@ class TrainMixin(
251
257
  for k, v in d.items():
252
258
  self.logger.log_scalar(k, v, namespace=ns)
253
259
 
254
- def log_step(self, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State) -> None:
260
+ def log_step(
261
+ self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
262
+ ) -> None:
255
263
  phase = state.phase
256
264
 
257
265
  for k, v in metrics.items():
@@ -265,9 +273,9 @@ class TrainMixin(
265
273
  # Delegate to the appropriate logging function based on the phase.
266
274
  match phase:
267
275
  case "train":
268
- self.log_train_step(batch, output, metrics, state)
276
+ self.log_train_step(model, batch, output, metrics, state)
269
277
  case "valid":
270
- self.log_valid_step(batch, output, metrics, state)
278
+ self.log_valid_step(model, batch, output, metrics, state)
271
279
  case _:
272
280
  raise KeyError(f"Unknown phase: {phase}")
273
281
 
@@ -579,7 +587,7 @@ class TrainMixin(
579
587
  )
580
588
 
581
589
  output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
582
- self.log_step(valid_batch, output, metrics, state)
590
+ self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
583
591
 
584
592
  state = self.on_step_start(state)
585
593
  train_batch = next(train_pf)
@@ -597,7 +605,7 @@ class TrainMixin(
597
605
  batch=train_batch,
598
606
  state=state,
599
607
  )
600
- self.log_step(train_batch, output, metrics, state)
608
+ self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
601
609
 
602
610
  state = self.on_step_end(state)
603
611
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.11
3
+ Version: 0.1.12
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=2JdSxsZphJJFVMGBVXNc0hP2p0FVOu5y7xSgPRNeyNY,13835
1
+ xax/__init__.py,sha256=7vdTYO7jAJdDxKZURlFxc3Y5kr5mVQcTQjeh_sYjD6I,13834
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
@@ -10,11 +10,11 @@ xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
10
10
  xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
11
11
  xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
12
12
  xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
13
- xax/nn/geom.py,sha256=eK7I8fUHBc3FT7zpm5Yf__bXFQ4LtX6sa17-DxojLTo,3202
13
+ xax/nn/geom.py,sha256=Bj9Z4Y-uoNQuaA_eB_MyG7yImZLuOq8KCLUj1l3daoc,4545
14
14
  xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
15
15
  xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
16
16
  xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
17
- xax/nn/ssm.py,sha256=eFeGkV1pkVGc0vNrQbykCbFnlPXQqsqVA_JVzLBHD28,9865
17
+ xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
18
18
  xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
19
  xax/task/base.py,sha256=E4l1yCrAkM2TVTbVYrmk6BoVHMkbD4IYsTT921XOyi0,7760
20
20
  xax/task/logger.py,sha256=1SZjVC6UCtZUoMPcpp3ckotL324QDeYDvHVhf5MHVqg,36271
@@ -41,7 +41,7 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
41
41
  xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
42
42
  xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
43
43
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
44
- xax/task/mixins/train.py,sha256=lgLHiHQtnDK0XS3SwHTYZtDv5CTbPRN1-p_K9KiIpHQ,26000
44
+ xax/task/mixins/train.py,sha256=aIebtOIvERYofSyqzNGBpNYlNrXweqFUqM9dHiTx3Dc,26253
45
45
  xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
46
46
  xax/utils/debugging.py,sha256=9WlCrEqbq-SVXPEM4rhsLYERH97XNX7XSYLSI3sgKGk,1619
47
47
  xax/utils/experiments.py,sha256=5CUja1H_cx4dnVqTGQekOpIhqISwHtAgLxZ34GV7cwM,29229
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
58
58
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
59
59
  xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
60
60
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
61
- xax-0.1.11.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
- xax-0.1.11.dist-info/METADATA,sha256=qDhn5EGxdiuEe5gQUZiBC430sXhJOPRWboTvsh2onxs,1878
63
- xax-0.1.11.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
- xax-0.1.11.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
- xax-0.1.11.dist-info/RECORD,,
61
+ xax-0.1.12.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
62
+ xax-0.1.12.dist-info/METADATA,sha256=hLRAX5__7QjBgjzhxbRftGvEsNrt8IAdgd22dMtHu_Y,1878
63
+ xax-0.1.12.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
64
+ xax-0.1.12.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
65
+ xax-0.1.12.dist-info/RECORD,,
File without changes