xax 0.1.11__tar.gz → 0.1.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.1.11/xax.egg-info → xax-0.1.12}/PKG-INFO +1 -1
- {xax-0.1.11 → xax-0.1.12}/xax/__init__.py +9 -10
- {xax-0.1.11 → xax-0.1.12}/xax/nn/geom.py +57 -0
- xax-0.1.12/xax/nn/ssm.py +316 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/train.py +15 -7
- {xax-0.1.11 → xax-0.1.12/xax.egg-info}/PKG-INFO +1 -1
- xax-0.1.11/xax/nn/ssm.py +0 -296
- {xax-0.1.11 → xax-0.1.12}/LICENSE +0 -0
- {xax-0.1.11 → xax-0.1.12}/MANIFEST.in +0 -0
- {xax-0.1.11 → xax-0.1.12}/README.md +0 -0
- {xax-0.1.11 → xax-0.1.12}/pyproject.toml +0 -0
- {xax-0.1.11 → xax-0.1.12}/setup.cfg +0 -0
- {xax-0.1.11 → xax-0.1.12}/setup.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/core/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/core/conf.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/core/state.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/nn/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/nn/embeddings.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/nn/equinox.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/nn/export.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/nn/functions.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/nn/losses.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/nn/norm.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/nn/parallel.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/py.typed +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/requirements-dev.txt +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/requirements.txt +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/base.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/launchers/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/launchers/base.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/launchers/cli.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/launchers/single_process.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/logger.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/loggers/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/loggers/callback.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/loggers/json.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/loggers/state.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/loggers/stdout.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/compile.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/logger.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/process.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/runnable.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/script.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/task/task.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/data/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/data/collate.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/debugging.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/experiments.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/jax.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/jaxpr.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/logging.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/numpy.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/profile.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/pytree.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/tensorboard.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/text.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/types/__init__.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax.egg-info/requires.txt +0 -0
- {xax-0.1.11 → xax-0.1.12}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.1.
|
15
|
+
__version__ = "0.1.12"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -43,15 +43,14 @@ __all__ = [
|
|
43
43
|
"euler_to_quat",
|
44
44
|
"get_projected_gravity_vector_from_quat",
|
45
45
|
"quat_to_euler",
|
46
|
+
"rotate_vector_by_quat",
|
46
47
|
"cross_entropy",
|
47
48
|
"cast_norm_type",
|
48
49
|
"get_norm",
|
49
50
|
"is_master",
|
51
|
+
"BaseSSMBlock",
|
50
52
|
"DiagSSMBlock",
|
51
|
-
"
|
52
|
-
"S4",
|
53
|
-
"S4Layer",
|
54
|
-
"S6Layer",
|
53
|
+
"SSM",
|
55
54
|
"SSMBlock",
|
56
55
|
"BaseLauncher",
|
57
56
|
"CliLauncher",
|
@@ -203,15 +202,14 @@ NAME_MAP: dict[str, str] = {
|
|
203
202
|
"euler_to_quat": "nn.geom",
|
204
203
|
"get_projected_gravity_vector_from_quat": "nn.geom",
|
205
204
|
"quat_to_euler": "nn.geom",
|
205
|
+
"rotate_vector_by_quat": "nn.geom",
|
206
206
|
"cross_entropy": "nn.losses",
|
207
207
|
"cast_norm_type": "nn.norm",
|
208
208
|
"get_norm": "nn.norm",
|
209
209
|
"is_master": "nn.parallel",
|
210
|
+
"BaseSSMBlock": "nn.ssm",
|
210
211
|
"DiagSSMBlock": "nn.ssm",
|
211
|
-
"
|
212
|
-
"S4": "nn.ssm",
|
213
|
-
"S4Layer": "nn.ssm",
|
214
|
-
"S6Layer": "nn.ssm",
|
212
|
+
"SSM": "nn.ssm",
|
215
213
|
"SSMBlock": "nn.ssm",
|
216
214
|
"BaseLauncher": "task.launchers.base",
|
217
215
|
"CliLauncher": "task.launchers.cli",
|
@@ -364,11 +362,12 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
364
362
|
euler_to_quat,
|
365
363
|
get_projected_gravity_vector_from_quat,
|
366
364
|
quat_to_euler,
|
365
|
+
rotate_vector_by_quat,
|
367
366
|
)
|
368
367
|
from xax.nn.losses import cross_entropy
|
369
368
|
from xax.nn.norm import NormType, cast_norm_type, get_norm
|
370
369
|
from xax.nn.parallel import is_master
|
371
|
-
from xax.nn.ssm import
|
370
|
+
from xax.nn.ssm import SSM, BaseSSMBlock, DiagSSMBlock, SSMBlock
|
372
371
|
from xax.task.base import RawConfigType
|
373
372
|
from xax.task.launchers.base import BaseLauncher
|
374
373
|
from xax.task.launchers.cli import CliLauncher
|
@@ -99,3 +99,60 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -
|
|
99
99
|
|
100
100
|
# Note: We're rotating [0,0,-1], so we negate gz to match the expected direction
|
101
101
|
return jnp.concatenate([gx, gy, -gz], axis=-1)
|
102
|
+
|
103
|
+
|
104
|
+
def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) -> jax.Array:
    """Rotate ``vector`` by the quaternion ``quat``.

    The quaternion is normalized first, then the rotation ``q v q*`` is
    evaluated in closed form, component by component.

    Args:
        vector: Vectors to rotate, shape (*, 3).
        quat: Rotation quaternions in (w, x, y, z) order, shape (*, 4).
        eps: Small value added to the quaternion norm so a zero quaternion
            does not cause a division by zero.

    Returns:
        The rotated vectors, shape (*, 3).
    """
    unit_quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
    qw, qx, qy, qz = jnp.split(unit_quat, 4, axis=-1)
    vx, vy, vz = jnp.split(vector, 3, axis=-1)

    # Each output component is the expanded form of (q v q*) for that axis.
    out_x = (
        qw * qw * vx
        + 2 * qy * qw * vz
        - 2 * qz * qw * vy
        + qx * qx * vx
        + 2 * qy * qx * vy
        + 2 * qz * qx * vz
        - qz * qz * vx
        - qy * qy * vx
    )
    out_y = (
        2 * qx * qy * vx
        + qy * qy * vy
        + 2 * qz * qy * vz
        + 2 * qw * qz * vx
        - qz * qz * vy
        + qw * qw * vy
        - 2 * qw * qx * vz
        - qx * qx * vy
    )
    out_z = (
        2 * qx * qz * vx
        + 2 * qy * qz * vy
        + qz * qz * vz
        - 2 * qw * qy * vx
        + qw * qw * vz
        + 2 * qw * qx * vy
        - qy * qy * vz
        - qx * qx * vz
    )

    components = [out_x, out_y, out_z]
    return jnp.concatenate(components, axis=-1)
|
xax-0.1.12/xax/nn/ssm.py
ADDED
@@ -0,0 +1,316 @@
|
|
1
|
+
"""State space models."""
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from typing import Literal
|
5
|
+
|
6
|
+
import equinox as eqx
|
7
|
+
import jax
|
8
|
+
import jax.numpy as jnp
|
9
|
+
from jaxtyping import Array, PRNGKeyArray
|
10
|
+
|
11
|
+
|
12
|
+
def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
    """Sample a uniform array scaled by the sum of its dimensions.

    Values are drawn uniformly from ``[-1, 1)`` and multiplied by
    ``sqrt(2 / sum(shape))``. Note that this bound differs from the textbook
    Glorot-uniform bound ``sqrt(6 / (fan_in + fan_out))``; it is left as-is to
    preserve the existing initialization.

    Args:
        key: PRNG key used for sampling.
        shape: Shape of the returned array.

    Returns:
        An array with the requested shape.
    """
    scale = jnp.sqrt(2 / sum(shape))
    unit_sample = jax.random.uniform(key, shape, minval=-1.0, maxval=1.0)
    return unit_sample * scale
|
14
|
+
|
15
|
+
|
16
|
+
class BaseSSMBlock(eqx.Module, ABC):
    """Interface shared by the state-space blocks in this module.

    Subclasses provide a single-step recurrence (``forward``), a pass over a
    whole sequence (``forward_sequence``), and accessors for the transition
    (``get_a_mat``) and input (``get_b_mat``) parameters. The accessors receive
    the input, which would allow input-dependent parameters.
    """

    @abstractmethod
    def get_a_mat(self, x: Array) -> Array:
        """Return the state transition parameters for input ``x``."""

    @abstractmethod
    def get_b_mat(self, x: Array) -> Array:
        """Return the input projection parameters for input ``x``."""

    @abstractmethod
    def forward(self, h: Array, x: Array) -> Array:
        """Advance hidden state ``h`` by one step given input ``x``."""

    @abstractmethod
    def forward_sequence(self, x_seq: Array) -> Array:
        """Run the block over an entire input sequence."""
|
28
|
+
|
29
|
+
|
30
|
+
class SSMBlock(BaseSSMBlock):
    """State-space block with dense (full-rank) transition and input matrices.

    Implements ``h_t = A @ h_{t-1} + B.T @ x_t``, starting from a zero state.
    """

    a_mat: Array  # (H, H) state transition matrix.
    b_mat: Array  # (H, H) input matrix, applied as ``b_mat.T @ x``.

    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
        a_key, b_key = jax.random.split(key)
        square = (hidden_size, hidden_size)
        self.a_mat = glorot(a_key, square)
        self.b_mat = glorot(b_key, square)

    def get_a_mat(self, x: Array) -> Array:
        """Return the transition matrix (it does not depend on ``x``)."""
        return self.a_mat

    def get_b_mat(self, x: Array) -> Array:
        """Return the input matrix (it does not depend on ``x``)."""
        return self.b_mat

    def forward(self, h: Array, x: Array) -> Array:
        """Apply one step of the recurrence.

        Args:
            h: Hidden state of shape (H,).
            x: Input of shape (H,).

        Returns:
            Hidden state of shape (H,).
        """
        transition = self.get_a_mat(x)
        projection = self.get_b_mat(x)
        return transition @ h + projection.T @ x

    def forward_sequence(self, x_seq: Array) -> Array:
        """Apply the recurrence across time with ``jax.lax.scan``.

        Args:
            x_seq: Input sequence of shape (T, H).

        Returns:
            Hidden state sequence of shape (T, H).
        """
        state_size = self.get_a_mat(x_seq).shape[0]

        def _scan_fn(carry: Array, x_t: Array) -> tuple[Array, Array]:
            next_state = self.forward(carry, x_t)
            return next_state, next_state

        _, states = jax.lax.scan(_scan_fn, jnp.zeros(state_size), x_seq)
        return states
|
78
|
+
|
79
|
+
|
80
|
+
class DiagSSMBlock(BaseSSMBlock):
    """State-space block with a diagonal state transition.

    Implements ``h_t = a * h_{t-1} + B.T @ x_t`` (elementwise ``a``), starting
    from a zero state. Because the transition is diagonal, unrolling gives
    ``h_t = sum_k a**k * s_{t-k}`` with ``s_t = B.T @ x_t``. The whole sequence
    can therefore also be computed as a depthwise causal convolution.
    """

    a_diag: Array  # (H,) diagonal of the transition matrix.
    b_mat: Array  # (H, H) input matrix, applied as ``b_mat.T @ x``.

    def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
        keys = jax.random.split(key, 2)
        self.a_diag = glorot(keys[0], (hidden_size,))
        self.b_mat = glorot(keys[1], (hidden_size, hidden_size))

    def get_a_mat(self, x: Array) -> Array:
        """Return the diagonal of the transition matrix."""
        return self.a_diag

    def get_b_mat(self, x: Array) -> Array:
        """Return the input matrix."""
        return self.b_mat

    def forward(self, h: Array, x: Array) -> Array:
        """Perform a forward pass.

        Args:
            h: Hidden state of shape (H,).
            x: Input of shape (H,).

        Returns:
            Hidden state of shape (H,).
        """
        a_diag = self.get_a_mat(x)
        b_mat = self.get_b_mat(x)
        h = a_diag * h + b_mat.T @ x
        return h

    def forward_sequence(self, x_seq: Array, *, use_conv: bool = True, recursive_kernel_calc: bool = False) -> Array:
        """Perform a potentially parallelized forward pass across time.

        Args:
            x_seq: Input sequence of shape (T, H).
            use_conv: Whether to compute the sequence as a causal convolution
                instead of a sequential scan.
            recursive_kernel_calc: Whether to build the convolution kernel by
                repeated doubling instead of ``jnp.power``.

        Returns:
            Hidden state sequence of shape (T, H).
        """
        if use_conv:
            return self._forward_sequence_conv(x_seq, recursive_kernel_calc=recursive_kernel_calc)
        else:
            return self._forward_sequence_scan(x_seq)

    def _get_kernel(self, x_seq: Array, length: int) -> Array:
        """Return the kernel ``a**k`` for ``k < length``, with time as the final dimension."""
        exponents = jnp.arange(length)
        a_diag = self.get_a_mat(x_seq)
        kernel = jnp.power(a_diag[:, None], exponents)  # (H, T)
        kernel = kernel[:, None, :]  # (H, 1, T)
        return kernel

    def _get_kernel_recursive(self, x_seq: Array, length: int) -> Array:
        """Return the kernel ``a**k`` for ``k < length`` using repeated doubling.

        Starting from ``[a**0]``, each step appends ``a**n`` times the current
        kernel of length ``n`` (giving length ``2n``) and squares ``a**n``.
        The kernel is grown to the smallest power of two ``>= length`` and then
        truncated, so any positive length is supported.

        Raises:
            ValueError: If ``length`` is not positive.
        """
        if length < 1:
            raise ValueError(f"Length must be positive, got {length}.")
        a_diag = self.get_a_mat(x_seq)

        kernel = jnp.ones_like(a_diag)[:, None]  # a**0, shape (H, 1)
        a_pow = a_diag[:, None]  # a**n, where n is the current kernel length
        while kernel.shape[-1] < length:
            kernel = jnp.concatenate([kernel, a_pow * kernel], axis=-1)
            a_pow = a_pow * a_pow
        return kernel[:, None, :length]  # (H, 1, T)

    def _forward_sequence_conv(self, x_seq: Array, *, recursive_kernel_calc: bool = False) -> Array:
        """Convolve x (T, H) across time using the kernel."""
        seq_len, hidden_size = x_seq.shape
        b_mat = self.get_b_mat(x_seq)

        s = b_mat.T @ x_seq.T  # (H, T)
        # Left-pad by T-1 so the VALID convolution below is causal and yields T outputs.
        s_padded = jnp.pad(s, ((0, 0), (seq_len - 1, 0)))[None, :, :]  # (1, H, 2T-1)

        if recursive_kernel_calc:
            kernel = self._get_kernel_recursive(x_seq, seq_len)
        else:
            kernel = self._get_kernel(x_seq, seq_len)

        # conv_general_dilated computes a cross-correlation; flipping the kernel
        # turns it into the convolution sum_k a**k * s_{t-k}.
        kernel_flipped = jnp.flip(kernel, axis=-1)  # (H, 1, T)

        conv_out = jax.lax.conv_general_dilated(
            s_padded,
            kernel_flipped,
            window_strides=(1,),
            padding="VALID",
            dimension_numbers=("NCT", "OIT", "NCT"),  # convolving over time
            feature_group_count=hidden_size,  # depthwise: one kernel per channel
        )
        conv_out = conv_out[0].T  # (T, H)
        return conv_out

    def _forward_sequence_scan(self, x_seq: Array) -> Array:
        """Naively forward across time."""

        def step(h: Array, x: Array) -> tuple[Array, Array]:
            h = self.forward(h, x)
            return h, h

        a_diag = self.get_a_mat(x_seq)
        h_0 = jnp.zeros(a_diag.shape[0])
        _, h_seq = jax.lax.scan(step, h_0, x_seq)
        return h_seq
|
189
|
+
|
190
|
+
|
191
|
+
class DiscreteDiagSSMBlock(DiagSSMBlock):
    """Diagonal SSM block whose parameters are discretized with a zero-order hold.

    ``a_diag`` and ``b_mat`` are treated as continuous-time parameters and are
    converted to discrete-time ones using the learned step size ``delta``:

        A_bar = exp(delta * a)
        B_bar = (exp(delta * a) - 1) / (delta * a) * delta * B
    """

    delta: Array  # Scalar discretization step size.

    def __init__(
        self,
        hidden_size: int,
        *,
        key: PRNGKeyArray,
        init_delta: float = 1.0,
        init_scale: float = 10.0,
    ) -> None:
        # Split the key so the parent's parameters and ``a_diag`` are drawn from
        # independent streams rather than reusing one key for both.
        parent_key, a_key = jax.random.split(key)
        super().__init__(hidden_size, key=parent_key)
        self.delta = jnp.array(init_delta)

        # A positive scale helps reduce the gradient at the start. The draw lies
        # in [-init_scale, 0), so for positive delta exp(delta * a) < 1 and the
        # discretized recurrence decays. This overwrites the parent's a_diag.
        self.a_diag = jax.random.uniform(a_key, (hidden_size,), minval=-1.0, maxval=0.0) * init_scale

    def get_a_mat(self, x: Array) -> Array:
        """Discretize the diagonal matrix using zero-order hold."""
        a_diag_discrete = jnp.exp(self.a_diag * self.delta)
        return a_diag_discrete

    def get_b_mat(self, x: Array) -> Array:
        """Discretize the input matrix using zero-order hold.

        The per-state factor ``(exp(delta * a) - 1) / (delta * a)`` broadcasts
        over the last axis of ``b_mat``, which is the hidden-state axis because
        the matrix is applied as ``b_mat.T @ x``. The factor is computed with
        ``expm1`` for accuracy. At ``delta * a == 0`` its limit (1) is used, so
        the result stays finite instead of becoming ``0 / 0``.
        """
        delta_a_diag = self.a_diag * self.delta
        is_small = jnp.abs(delta_a_diag) < 1e-6
        # Double-`where`: keep the discarded branch free of 0 / 0 so neither the
        # value nor its gradient becomes NaN.
        safe_delta_a = jnp.where(is_small, jnp.ones_like(delta_a_diag), delta_a_diag)
        zoh_factor = jnp.where(is_small, 1.0 + 0.5 * delta_a_diag, jnp.expm1(safe_delta_a) / safe_delta_a)
        delta_b_mat = self.delta * self.b_mat

        b_discrete = zoh_factor * delta_b_mat
        return b_discrete
|
222
|
+
|
223
|
+
|
224
|
+
class SSM(eqx.Module):
|
225
|
+
vocab_embedding: eqx.nn.Embedding
|
226
|
+
output_layer: eqx.nn.Linear
|
227
|
+
blocks: list[BaseSSMBlock]
|
228
|
+
num_layers: int = eqx.static_field()
|
229
|
+
hidden_size: int = eqx.static_field()
|
230
|
+
skip_connections: bool = eqx.static_field()
|
231
|
+
|
232
|
+
def __init__(
|
233
|
+
self,
|
234
|
+
input_size: int,
|
235
|
+
hidden_size: int,
|
236
|
+
output_size: int,
|
237
|
+
num_layers: int,
|
238
|
+
block_type: Literal["diagonal", "full_rank"] = "full_rank",
|
239
|
+
skip_connections: bool = False,
|
240
|
+
discretize: bool = False,
|
241
|
+
*,
|
242
|
+
key: PRNGKeyArray,
|
243
|
+
) -> None:
|
244
|
+
vocab_key, s4_key = jax.random.split(key, 2)
|
245
|
+
self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
|
246
|
+
self.output_layer = eqx.nn.Linear(hidden_size, output_size, key=key)
|
247
|
+
|
248
|
+
block_keys = jax.random.split(s4_key, num_layers)
|
249
|
+
|
250
|
+
def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
|
251
|
+
match block_type:
|
252
|
+
case "diagonal":
|
253
|
+
return (
|
254
|
+
DiscreteDiagSSMBlock(hidden_size, key=key, init_delta=0.1)
|
255
|
+
if discretize
|
256
|
+
else DiagSSMBlock(hidden_size, key=key)
|
257
|
+
)
|
258
|
+
case "full_rank":
|
259
|
+
if discretize:
|
260
|
+
raise ValueError("Full rank blocks do not support discretization due to instability.")
|
261
|
+
return SSMBlock(hidden_size, key=key)
|
262
|
+
case _:
|
263
|
+
raise ValueError(f"Unknown block type: {block_type}")
|
264
|
+
|
265
|
+
self.blocks = [get_block(block_keys[i]) for i in range(num_layers)]
|
266
|
+
self.skip_connections = skip_connections
|
267
|
+
self.num_layers = num_layers
|
268
|
+
self.hidden_size = hidden_size
|
269
|
+
|
270
|
+
def __call__(self, hs: list[Array], x: Array) -> tuple[list[Array], Array]:
|
271
|
+
new_hs = []
|
272
|
+
for i, block in enumerate(self.blocks):
|
273
|
+
h = block.forward(hs[i], x)
|
274
|
+
new_hs.append(h)
|
275
|
+
xh = jax.nn.gelu(h)
|
276
|
+
x = xh + x if self.skip_connections else xh
|
277
|
+
y = self.output_layer(x)
|
278
|
+
return new_hs, y
|
279
|
+
|
280
|
+
def _embed_input(self, x: Array) -> Array:
|
281
|
+
"""U is the input to the S4 cell."""
|
282
|
+
return self.vocab_embedding(x)
|
283
|
+
|
284
|
+
def predict_sequence(self, x_seq: Array) -> Array:
|
285
|
+
x_emb = jax.vmap(self._embed_input)(x_seq)
|
286
|
+
for block in self.blocks:
|
287
|
+
h = block.forward_sequence(x_emb)
|
288
|
+
# h = block.naive_forward_sequence(x_emb)
|
289
|
+
h = jax.nn.gelu(h)
|
290
|
+
x_emb = h + x_emb if self.skip_connections else h
|
291
|
+
y = jax.vmap(self.output_layer)(x_emb)
|
292
|
+
return y
|
293
|
+
|
294
|
+
def generate_sequence(self, prompt_seq: Array, max_len: int) -> Array:
|
295
|
+
hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
|
296
|
+
prompt_seq_embedded = jax.vmap(self._embed_input)(prompt_seq)
|
297
|
+
|
298
|
+
def encode_step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
|
299
|
+
hs, y = self(hs, x)
|
300
|
+
return hs, y
|
301
|
+
|
302
|
+
def decode_step(
|
303
|
+
carry: tuple[list[Array], Array, PRNGKeyArray],
|
304
|
+
_: None,
|
305
|
+
) -> tuple[tuple[list[Array], Array, PRNGKeyArray], Array]:
|
306
|
+
hs, last_token, rng = carry
|
307
|
+
token_embedded = self._embed_input(last_token)
|
308
|
+
hs, y = self(hs, token_embedded)
|
309
|
+
token = jax.random.categorical(rng, y)
|
310
|
+
rng = jax.random.split(rng)[0]
|
311
|
+
return (hs, token, rng), token
|
312
|
+
|
313
|
+
hs, _ = jax.lax.scan(encode_step, hs, prompt_seq_embedded)
|
314
|
+
_, sequence = jax.lax.scan(decode_step, (hs, prompt_seq[-1], jax.random.PRNGKey(0)), None, length=max_len)
|
315
|
+
|
316
|
+
return sequence
|
@@ -218,26 +218,32 @@ class TrainMixin(
|
|
218
218
|
state = super().on_step_end(state)
|
219
219
|
return state.replace(elapsed_time_s=time.time() - state.start_time_s)
|
220
220
|
|
221
|
-
def log_train_step(
|
221
|
+
def log_train_step(
|
222
|
+
self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
|
223
|
+
) -> None:
|
222
224
|
"""Override this function to do logging during the training phase.
|
223
225
|
|
224
226
|
This function is called after the model forward pass and before the
|
225
227
|
backward pass. It is called in the training phase.
|
226
228
|
|
227
229
|
Args:
|
230
|
+
model: The current model.
|
228
231
|
batch: The batch from the dataloader.
|
229
232
|
output: The model output.
|
230
233
|
metrics: The metrics for the current batch.
|
231
234
|
state: The current training state.
|
232
235
|
"""
|
233
236
|
|
234
|
-
def log_valid_step(
|
237
|
+
def log_valid_step(
|
238
|
+
self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
|
239
|
+
) -> None:
|
235
240
|
"""Override this function to do logging during the validation phase.
|
236
241
|
|
237
242
|
This function is called after the model forward pass. It is called in
|
238
243
|
the validation phase.
|
239
244
|
|
240
245
|
Args:
|
246
|
+
model: The current model.
|
241
247
|
batch: The batch from the dataloader.
|
242
248
|
output: The model output.
|
243
249
|
metrics: The metrics for the current batch.
|
@@ -251,7 +257,9 @@ class TrainMixin(
|
|
251
257
|
for k, v in d.items():
|
252
258
|
self.logger.log_scalar(k, v, namespace=ns)
|
253
259
|
|
254
|
-
def log_step(
|
260
|
+
def log_step(
|
261
|
+
self, model: PyTree, batch: Batch, output: Output, metrics: FrozenDict[str, Array], state: State
|
262
|
+
) -> None:
|
255
263
|
phase = state.phase
|
256
264
|
|
257
265
|
for k, v in metrics.items():
|
@@ -265,9 +273,9 @@ class TrainMixin(
|
|
265
273
|
# Delegate to the appropriate logging function based on the phase.
|
266
274
|
match phase:
|
267
275
|
case "train":
|
268
|
-
self.log_train_step(batch, output, metrics, state)
|
276
|
+
self.log_train_step(model, batch, output, metrics, state)
|
269
277
|
case "valid":
|
270
|
-
self.log_valid_step(batch, output, metrics, state)
|
278
|
+
self.log_valid_step(model, batch, output, metrics, state)
|
271
279
|
case _:
|
272
280
|
raise KeyError(f"Unknown phase: {phase}")
|
273
281
|
|
@@ -579,7 +587,7 @@ class TrainMixin(
|
|
579
587
|
)
|
580
588
|
|
581
589
|
output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
|
582
|
-
self.log_step(valid_batch, output, metrics, state)
|
590
|
+
self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
|
583
591
|
|
584
592
|
state = self.on_step_start(state)
|
585
593
|
train_batch = next(train_pf)
|
@@ -597,7 +605,7 @@ class TrainMixin(
|
|
597
605
|
batch=train_batch,
|
598
606
|
state=state,
|
599
607
|
)
|
600
|
-
self.log_step(train_batch, output, metrics, state)
|
608
|
+
self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
|
601
609
|
|
602
610
|
state = self.on_step_end(state)
|
603
611
|
|
xax-0.1.11/xax/nn/ssm.py
DELETED
@@ -1,296 +0,0 @@
|
|
1
|
-
"""State space models."""
|
2
|
-
|
3
|
-
from abc import ABC, abstractmethod
|
4
|
-
from typing import Literal
|
5
|
-
|
6
|
-
import equinox as eqx
|
7
|
-
import jax
|
8
|
-
import jax.numpy as jnp
|
9
|
-
from jaxtyping import Array, PRNGKeyArray
|
10
|
-
|
11
|
-
|
12
|
-
def glorot(key: PRNGKeyArray, shape: tuple[int, ...]) -> Array:
|
13
|
-
return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) * jnp.sqrt(2 / sum(shape))
|
14
|
-
|
15
|
-
|
16
|
-
class DiscreteTimeS4(eqx.Module):
|
17
|
-
a: Array
|
18
|
-
B: Array
|
19
|
-
C: Array
|
20
|
-
proj_in: eqx.nn.Linear
|
21
|
-
proj_out: eqx.nn.Linear
|
22
|
-
|
23
|
-
def __init__(
|
24
|
-
self,
|
25
|
-
hidden_size: int,
|
26
|
-
projection_size: int,
|
27
|
-
input_size: int,
|
28
|
-
output_size: int,
|
29
|
-
*,
|
30
|
-
key: PRNGKeyArray,
|
31
|
-
) -> None:
|
32
|
-
self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
|
33
|
-
self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
|
34
|
-
self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
|
35
|
-
self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
|
36
|
-
self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
|
37
|
-
|
38
|
-
def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
|
39
|
-
h = self.a * h + self.B.T @ x
|
40
|
-
y = self.C.T @ h
|
41
|
-
return h, y
|
42
|
-
|
43
|
-
def predict_sequence(self, x_seq: Array) -> Array:
|
44
|
-
x_proj = jax.vmap(lambda x: jax.nn.relu(self.proj_in(x)))(x_seq)
|
45
|
-
h = jnp.zeros(self.a.shape[0])
|
46
|
-
|
47
|
-
def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
|
48
|
-
h = self.a * h + self.B.T @ x
|
49
|
-
y = self.C.T @ h
|
50
|
-
return h, y
|
51
|
-
|
52
|
-
_, y_seq = jax.lax.scan(scan_fn, h, x_proj)
|
53
|
-
y_out = jax.vmap(self.proj_out)(y_seq)
|
54
|
-
return y_out
|
55
|
-
|
56
|
-
|
57
|
-
class S4Layer(eqx.Module):
|
58
|
-
a: Array
|
59
|
-
B: Array
|
60
|
-
C: Array
|
61
|
-
proj_in: eqx.nn.Linear
|
62
|
-
proj_out: eqx.nn.Linear
|
63
|
-
delta: Array
|
64
|
-
|
65
|
-
def __init__(
|
66
|
-
self,
|
67
|
-
hidden_size: int,
|
68
|
-
projection_size: int,
|
69
|
-
input_size: int,
|
70
|
-
output_size: int,
|
71
|
-
*,
|
72
|
-
key: PRNGKeyArray,
|
73
|
-
) -> None:
|
74
|
-
self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
|
75
|
-
self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
|
76
|
-
self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
|
77
|
-
self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
|
78
|
-
self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
|
79
|
-
self.delta = jax.random.uniform(key, (hidden_size,))
|
80
|
-
|
81
|
-
def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
|
82
|
-
delta_a = self.delta * self.a
|
83
|
-
a_bar = jnp.exp(delta_a)
|
84
|
-
b_bar = jnp.linalg.inv(delta_a) * (a_bar - 1) @ (self.delta * self.B)
|
85
|
-
h = a_bar * h + b_bar.T @ x
|
86
|
-
y = self.C.T @ h
|
87
|
-
return h, y
|
88
|
-
|
89
|
-
def predict_sequence(self, x_seq: Array) -> Array:
|
90
|
-
x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
|
91
|
-
h = jnp.zeros(self.a.shape[0])
|
92
|
-
|
93
|
-
def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
|
94
|
-
h = self.a * h + self.B.T @ x
|
95
|
-
y = self.C.T @ h
|
96
|
-
return h, y
|
97
|
-
|
98
|
-
_, y_seq = jax.lax.scan(scan_fn, h, x_proj)
|
99
|
-
y_out = jax.vmap(self.proj_out)(y_seq)
|
100
|
-
return y_out
|
101
|
-
|
102
|
-
|
103
|
-
class S6Layer(eqx.Module):
|
104
|
-
a: Array
|
105
|
-
B: Array
|
106
|
-
C: Array
|
107
|
-
proj_in: eqx.nn.Linear
|
108
|
-
proj_out: eqx.nn.Linear
|
109
|
-
delta: Array
|
110
|
-
|
111
|
-
def __init__(
|
112
|
-
self,
|
113
|
-
hidden_size: int,
|
114
|
-
projection_size: int,
|
115
|
-
input_size: int,
|
116
|
-
output_size: int,
|
117
|
-
*,
|
118
|
-
key: PRNGKeyArray,
|
119
|
-
) -> None:
|
120
|
-
self.a = jax.nn.initializers.glorot_uniform()(key, (hidden_size,))
|
121
|
-
self.B = jax.nn.initializers.glorot_uniform()(key, (projection_size, hidden_size))
|
122
|
-
self.C = jax.nn.initializers.glorot_uniform()(key, (hidden_size, projection_size))
|
123
|
-
self.proj_in = eqx.nn.Linear(input_size, projection_size, key=key)
|
124
|
-
self.proj_out = eqx.nn.Linear(projection_size, output_size, key=key)
|
125
|
-
self.delta = jax.random.uniform(key, (hidden_size,))
|
126
|
-
|
127
|
-
def __call__(self, h: Array, x: Array) -> tuple[Array, Array]:
|
128
|
-
h = self.a * h + self.B.T @ x
|
129
|
-
y = self.C.T @ h
|
130
|
-
return h, y
|
131
|
-
|
132
|
-
def predict_sequence(self, x_seq: Array) -> Array:
|
133
|
-
x_proj = jax.vmap(lambda x: jax.nn.gelu(self.proj_in(x)))(x_seq)
|
134
|
-
h = jnp.zeros(self.a.shape[0])
|
135
|
-
|
136
|
-
def scan_fn(h: Array, x: Array) -> tuple[Array, Array]:
|
137
|
-
h = self.a * h + self.B.T @ x
|
138
|
-
y = self.C.T @ h
|
139
|
-
return h, y
|
140
|
-
|
141
|
-
_, y_seq = jax.lax.scan(scan_fn, h, x_proj)
|
142
|
-
y_out = jax.vmap(self.proj_out)(y_seq)
|
143
|
-
return y_out
|
144
|
-
|
145
|
-
|
146
|
-
class BaseSSMBlock(eqx.Module, ABC):
|
147
|
-
@abstractmethod
|
148
|
-
def forward(self, h: Array, x: Array) -> Array:
|
149
|
-
pass
|
150
|
-
|
151
|
-
|
152
|
-
class SSMBlock(BaseSSMBlock):
|
153
|
-
a_mat: Array
|
154
|
-
b_mat: Array
|
155
|
-
|
156
|
-
def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
|
157
|
-
key_a, key_b = jax.random.split(key)
|
158
|
-
self.a_mat = glorot(key_a, (hidden_size, hidden_size))
|
159
|
-
self.b_mat = glorot(key_b, (hidden_size, hidden_size))
|
160
|
-
|
161
|
-
def forward(self, h: Array, x: Array) -> Array:
|
162
|
-
h = self.a_mat @ h + self.b_mat.T @ x
|
163
|
-
return h
|
164
|
-
|
165
|
-
def get_kernel(self, length: int) -> Array:
|
166
|
-
return self.a_mat
|
167
|
-
|
168
|
-
|
169
|
-
class DiagSSMBlock(BaseSSMBlock):
|
170
|
-
a_mat: Array
|
171
|
-
b_mat: Array
|
172
|
-
|
173
|
-
def __init__(self, hidden_size: int, *, key: PRNGKeyArray) -> None:
|
174
|
-
keys = jax.random.split(key, 2)
|
175
|
-
self.a_mat = glorot(keys[0], (hidden_size,))
|
176
|
-
self.b_mat = glorot(keys[1], (hidden_size, hidden_size))
|
177
|
-
|
178
|
-
def forward(self, h: Array, x: Array) -> Array:
|
179
|
-
h = self.a_mat * h + self.b_mat.T @ x
|
180
|
-
h = jax.nn.tanh(h)
|
181
|
-
return h
|
182
|
-
|
183
|
-
def get_kernel(self, length: int) -> Array:
|
184
|
-
"""Returns the kernel with time as the final dimension."""
|
185
|
-
exponents = jnp.arange(length)
|
186
|
-
kernel = jnp.power(self.a_mat[:, None], exponents) # (H, L)
|
187
|
-
kernel = kernel[:, None, :] # (H, 1, L)
|
188
|
-
return kernel
|
189
|
-
|
190
|
-
def forward_across_time(self, x: Array) -> Array:
|
191
|
-
"""Convolves x (T, H) across time using the kernel."""
|
192
|
-
tsz, nhid = x.shape
|
193
|
-
|
194
|
-
# Compute s = x @ U.T + b, with shape (N, T, H)
|
195
|
-
s = self.b_mat.T @ x
|
196
|
-
s = s.T # (H, T)
|
197
|
-
|
198
|
-
kernel = self.get_kernel(tsz) # (H, 1, T)
|
199
|
-
kernel_flipped = jnp.flip(kernel, axis=-1)
|
200
|
-
|
201
|
-
# Pad s on the left along the time axis (pad length T-1)
|
202
|
-
s_padded = jnp.pad(s, ((0, 0), (0, 0), (tsz - 1, 0)))
|
203
|
-
|
204
|
-
# Perform depthwise (grouped) 1D convolution.
|
205
|
-
# We use input shape (N, H, L) and kernel shape (H, 1, T) with feature_group_count=H.
|
206
|
-
# The dimension_numbers are chosen so that the channel dimension is second.
|
207
|
-
conv_out = jax.lax.conv_general_dilated(
|
208
|
-
s_padded,
|
209
|
-
kernel_flipped,
|
210
|
-
window_strides=(1,),
|
211
|
-
padding="VALID",
|
212
|
-
dimension_numbers=("NCH", "OIH", "NCH"),
|
213
|
-
feature_group_count=nhid,
|
214
|
-
)
|
215
|
-
# conv_out has shape (N, H, T); transpose to (N, T, H)
|
216
|
-
conv_out = jnp.transpose(conv_out, (0, 2, 1))
|
217
|
-
return conv_out
|
218
|
-
|
219
|
-
def naive_forward_accross_time(self, x: Array) -> Array:
|
220
|
-
"""Naively forward across time."""
|
221
|
-
|
222
|
-
def step(h: Array, x: Array) -> tuple[Array, Array]:
|
223
|
-
h = self.forward(h, x)
|
224
|
-
return h, h
|
225
|
-
|
226
|
-
h_0 = jnp.zeros(self.a_mat.shape[0])
|
227
|
-
_, h_seq = jax.lax.scan(step, h_0, x)
|
228
|
-
return h_seq
|
229
|
-
|
230
|
-
|
231
|
-
class S4(eqx.Module):
|
232
|
-
vocab_embedding: eqx.nn.Embedding
|
233
|
-
proj_in: eqx.nn.Linear
|
234
|
-
proj_out: eqx.nn.Linear
|
235
|
-
blocks: list[BaseSSMBlock]
|
236
|
-
num_layers: int = eqx.static_field()
|
237
|
-
hidden_size: int = eqx.static_field()
|
238
|
-
skip_connections: bool = eqx.static_field()
|
239
|
-
|
240
|
-
def __init__(
|
241
|
-
self,
|
242
|
-
input_size: int,
|
243
|
-
hidden_size: int,
|
244
|
-
output_size: int,
|
245
|
-
num_layers: int,
|
246
|
-
block_type: Literal["ssm", "diag"] = "ssm",
|
247
|
-
skip_connections: bool = False,
|
248
|
-
*,
|
249
|
-
key: PRNGKeyArray,
|
250
|
-
) -> None:
|
251
|
-
vocab_key, s4_key = jax.random.split(key, 2)
|
252
|
-
self.vocab_embedding = eqx.nn.Embedding(input_size, hidden_size, key=vocab_key)
|
253
|
-
self.proj_in = eqx.nn.Linear(hidden_size, hidden_size, key=key)
|
254
|
-
self.proj_out = eqx.nn.Linear(hidden_size, output_size, key=key)
|
255
|
-
|
256
|
-
block_keys = jax.random.split(s4_key, num_layers)
|
257
|
-
|
258
|
-
def get_block(key: PRNGKeyArray) -> BaseSSMBlock:
|
259
|
-
match block_type:
|
260
|
-
case "ssm":
|
261
|
-
return SSMBlock(hidden_size, key=key)
|
262
|
-
case "diag":
|
263
|
-
return DiagSSMBlock(hidden_size, key=key)
|
264
|
-
case _:
|
265
|
-
raise ValueError(f"Unknown block type: {block_type}")
|
266
|
-
|
267
|
-
self.blocks = [get_block(block_keys[i]) for i in range(num_layers)]
|
268
|
-
self.skip_connections = skip_connections
|
269
|
-
self.num_layers = num_layers
|
270
|
-
self.hidden_size = hidden_size
|
271
|
-
|
272
|
-
def __call__(self, hs: list[Array], x: Array) -> tuple[list[Array], Array]:
|
273
|
-
new_hs = []
|
274
|
-
for i, block in enumerate(self.blocks):
|
275
|
-
h = block.forward(hs[i], x)
|
276
|
-
new_hs.append(h)
|
277
|
-
xh = jax.nn.gelu(h)
|
278
|
-
x = xh + x if self.skip_connections else xh
|
279
|
-
y = self.proj_out(x)
|
280
|
-
return new_hs, y
|
281
|
-
|
282
|
-
def _embed_input(self, x: Array) -> Array:
|
283
|
-
"""U is the input to the S4 cell."""
|
284
|
-
embedded = self.vocab_embedding(x)
|
285
|
-
return jax.nn.gelu(self.proj_in(embedded))
|
286
|
-
|
287
|
-
def predict_sequence(self, x_seq: Array) -> Array:
|
288
|
-
x_emb = jax.vmap(self._embed_input)(x_seq)
|
289
|
-
hs = [jnp.zeros(self.hidden_size) for _ in range(self.num_layers)]
|
290
|
-
|
291
|
-
def step(hs: list[Array], x: Array) -> tuple[list[Array], Array]:
|
292
|
-
hs, y = self(hs, x)
|
293
|
-
return hs, y
|
294
|
-
|
295
|
-
_, y_seq = jax.lax.scan(step, hs, x_emb)
|
296
|
-
return y_seq
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|